VectorOps.cpp
1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
15
27#include "mlir/IR/AffineExpr.h"
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Builders.h"
33#include "mlir/IR/IRMapping.h"
37#include "mlir/IR/ValueRange.h"
40#include "mlir/Support/LLVM.h"
42#include "llvm/ADT/ArrayRef.h"
43#include "llvm/ADT/STLExtras.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/Casting.h"
49
50#include <cassert>
51#include <cstdint>
52#include <numeric>
53
54#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
55// Pull in all enum type and utility function definitions.
56#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
57
58using namespace mlir;
59using namespace mlir::vector;
60
61/// Helper enum to classify mask value.
62enum class MaskFormat {
63 AllTrue = 0,
64 AllFalse = 1,
65 Unknown = 2
66};
67
68/// Helper method to classify a mask value. Currently, the method
69/// looks "under the hood" of a constant value with dense attributes
70/// and a constant mask operation (since the client may be called at
71/// various stages during progressive lowering).
72static MaskFormat getMaskFormat(Value mask) {
73 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
74 // Inspect constant dense values. We count up for bits that
75 // are set, count down for bits that are cleared, and bail
76 // when a mix is detected.
77 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
78 int64_t val = 0;
79 for (bool b : denseElts.getValues<bool>())
80 if (b && val >= 0)
81 val++;
82 else if (!b && val <= 0)
83 val--;
84 else
85 return MaskFormat::Unknown;
86 if (val > 0)
87 return MaskFormat::AllTrue;
88 if (val < 0)
89 return MaskFormat::AllFalse;
90 }
91 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
92 // Inspect constant mask index. If the index exceeds the
93 // dimension size, all bits are set. If the index is zero
94 // or less, no bits are set.
95 ArrayRef<int64_t> masks = m.getMaskDimSizes();
96 auto shape = m.getType().getShape();
97 bool allTrue = true;
98 bool allFalse = true;
99 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
100 if (maskIdx < dimSize)
101 allTrue = false;
102 if (maskIdx > 0)
103 allFalse = false;
104 }
105 if (allTrue)
106 return MaskFormat::AllTrue;
107 if (allFalse)
108 return MaskFormat::AllFalse;
109 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
110 // Finds all-false create_masks. An all-true create_mask requires all
111 // dims to be constants, so that'll be folded to a constant_mask, then
112 // detected in the constant_mask case.
113 auto maskOperands = m.getOperands();
114 for (Value operand : maskOperands) {
115 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
116 int64_t dimSize =
117 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
118 if (dimSize <= 0)
119 return MaskFormat::AllFalse;
120 }
121 }
122 return MaskFormat::Unknown;
123 }
124 return MaskFormat::Unknown;
125}
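// Illustrative classification sketch (hypothetical values):
//
//   %0 = arith.constant dense<true> : vector<4xi1>      // AllTrue
//   %1 = vector.constant_mask [4, 3] : vector<4x3xi1>   // AllTrue
//   %2 = vector.constant_mask [0, 0] : vector<4x3xi1>   // AllFalse
//   %3 = vector.constant_mask [2, 3] : vector<4x3xi1>   // Unknown (mixed)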
126
127/// Default callback to build a region with a 'vector.yield' terminator with no
128/// arguments.
129void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
130 vector::YieldOp::create(builder, loc);
131}
132
133// Helper for verifying combining kinds in contractions and reductions.
134static bool isSupportedCombiningKind(CombiningKind combiningKind,
135 Type elementType) {
136 switch (combiningKind) {
137 case CombiningKind::ADD:
138 case CombiningKind::MUL:
139 return elementType.isIntOrIndexOrFloat();
140 case CombiningKind::MINUI:
141 case CombiningKind::MINSI:
142 case CombiningKind::MAXUI:
143 case CombiningKind::MAXSI:
144 case CombiningKind::AND:
145 case CombiningKind::OR:
146 case CombiningKind::XOR:
147 return elementType.isIntOrIndex();
148 case CombiningKind::MINNUMF:
149 case CombiningKind::MAXNUMF:
150 case CombiningKind::MINIMUMF:
151 case CombiningKind::MAXIMUMF:
152 return llvm::isa<FloatType>(elementType);
153 }
154 return false;
155}
156
157/// Returns the effective rank of the vector to read/write for Xfer Ops
158///
159/// When the element type of the shaped type is _a scalar_, this will simply
160/// return the rank of the vector (the result for xfer_read or the value to
161/// store for xfer_write).
162///
163/// When the element type of the base shaped type is _a vector_, returns the
164/// difference between the rank of the original vector and the rank of the
165/// element type of the shaped type.
166///
167/// EXAMPLE 1 (element type is _a scalar_):
168/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
169/// - shapedType.getElementType() = f32 (rank 0)
170/// - vectorType.getRank() = 2
171/// - Result = 2 - 0 = 2
172///
173/// EXAMPLE 2 (element type is _a vector_):
174/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
175/// - shapedType.getElementType() = vector<20xf32> (rank 1)
176/// - vectorType.getRank() = 1
177/// - Result = 1 - 1 = 0
178///
179/// This is used to determine the number of minor dimensions for identity maps
180/// in vector transfer Ops.
181static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
182 VectorType vectorType) {
183 unsigned elementVectorRank = 0;
184 VectorType elementVectorType =
185 llvm::dyn_cast<VectorType>(shapedType.getElementType());
186 if (elementVectorType)
187 elementVectorRank += elementVectorType.getRank();
188 return vectorType.getRank() - elementVectorRank;
189}
190
191AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
192 VectorType vectorType) {
193 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
194 // TODO: replace once we have 0-d vectors.
195 if (shapedType.getRank() == 0 &&
196 vectorType.getShape() == ArrayRef<int64_t>{1})
197 return AffineMap::get(
198 /*numDims=*/0, /*numSymbols=*/0,
199 getAffineConstantExpr(0, shapedType.getContext()));
200 return AffineMap::getMinorIdentityMap(
201 shapedType.getRank(),
202 getEffectiveVectorRankForXferOp(shapedType, vectorType),
203 shapedType.getContext());
204}
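// For illustration (hypothetical types): with shapedType = memref<10x20x30xf32>
// and vectorType = vector<4x5xf32>, the minor identity map returned is
// (d0, d1, d2) -> (d1, d2). For the 0-d case, e.g. memref<f32> read into
// vector<1xf32>, the constant map () -> (0) is returned instead.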
205
206/// Check if `write` is of a constant splat and the masked `read` is padded with
207/// the same splat value -- meaning it could be the same value as the initial
208/// constant splat.
209static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
210 vector::TransferReadOp read) {
211 auto readMask = read.getMask();
212 auto writeMask = write.getMask();
213 // Check if the masks are consistent. The splat value could be the same if the
214 // read is masked (and padded with the splat value), and the write is unmasked
215 // or has the same mask. Note this does not allow the case where the write is
216 // masked and the read is unmasked, as then the read could be of more elements
217 // than the write (which may not be the same value).
218 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
219 if (!couldBeSameSplat)
220 return false;
221 // Check for constant splat (as the source of the write).
222 DenseElementsAttr splatAttr;
223 if (!matchPattern(write.getVector(),
224 m_Constant<DenseElementsAttr>(&splatAttr)) ||
225 !splatAttr.isSplat()) {
226 return false;
227 }
228 // The padding of the read and the constant splat value must be the same.
229 Attribute padAttr;
230 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
231 return false;
232 return padAttr == splatAttr.getSplatValue<Attribute>();
233}
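// Illustrative IR sketch (names, shapes and values are hypothetical) where the
// helper returns true: the write stores a constant splat and the read is
// masked and padded with the same splat value.
//
//   %cst = arith.constant dense<0.0> : vector<4xf32>
//   %pad = arith.constant 0.0 : f32
//   %w   = vector.transfer_write %cst, %t[%c0] : vector<4xf32>, tensor<8xf32>
//   %r   = vector.transfer_read %w[%c0], %pad, %mask
//            : tensor<8xf32>, vector<4xf32>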
234
235bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
236 vector::TransferReadOp read) {
237 return !defWrite.hasOutOfBoundsDim() &&
238 defWrite.getIndices() == read.getIndices() &&
239 defWrite.getVectorType() == read.getVectorType() &&
240 defWrite.getPermutationMap() == read.getPermutationMap() &&
241 ((!defWrite.getMask() && !read.getMask()) ||
242 isSplatWriteConsistentWithMaskedRead(defWrite, read));
243}
244
245bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
246 vector::TransferWriteOp priorWrite) {
247 return priorWrite.getIndices() == write.getIndices() &&
248 priorWrite.getMask() == write.getMask() &&
249 priorWrite.getVectorType() == write.getVectorType() &&
250 priorWrite.getPermutationMap() == write.getPermutationMap();
251}
252
253bool mlir::vector::isDisjointTransferIndices(
254 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
255 bool testDynamicValueUsingBounds) {
256 // For simplicity only look at transfer of same type.
257 if (transferA.getVectorType() != transferB.getVectorType())
258 return false;
259 unsigned rankOffset = transferA.getLeadingShapedRank();
260 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
261 Value indexA = transferA.getIndices()[i];
262 Value indexB = transferB.getIndices()[i];
263 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
264 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
265
266 if (i < rankOffset) {
267 // For leading dimensions, if we can prove that the indices are different
268 // we know we are accessing disjoint slices.
269 if (cstIndexA.has_value() && cstIndexB.has_value()) {
270 if (*cstIndexA != *cstIndexB)
271 return true;
272 continue;
273 }
274 if (testDynamicValueUsingBounds) {
275 // First try to see if we can fully compose and simplify the affine
276 // expression as a fast track.
277 FailureOr<uint64_t> delta =
278 affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
279 if (succeeded(delta) && *delta != 0)
280 return true;
281
282 FailureOr<bool> testEqual =
283 ValueBoundsConstraintSet::areEqual(indexA, indexB);
284 if (succeeded(testEqual) && !testEqual.value())
285 return true;
286 }
287 } else {
288 // For this dimension, we slice a part of the memref, so we need to make
289 // sure the intervals accessed don't overlap.
290 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
291 if (cstIndexA.has_value() && cstIndexB.has_value()) {
292 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
293 if (distance >= vectorDim)
294 return true;
295 continue;
296 }
297 if (testDynamicValueUsingBounds) {
298 // First try to see if we can fully compose and simplify the affine
299 // expression as a fast track.
300 FailureOr<int64_t> delta =
301 affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
302 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
303 return true;
304
305 FailureOr<int64_t> computeDelta =
306 ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
307 if (succeeded(computeDelta)) {
308 if (std::abs(computeDelta.value()) >= vectorDim)
309 return true;
310 }
311 }
312 }
313 }
314 return false;
315}
316
317bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
318 VectorTransferOpInterface transferB,
319 bool testDynamicValueUsingBounds) {
320 if (transferA.getBase() != transferB.getBase())
321 return false;
322 return isDisjointTransferIndices(transferA, transferB,
323 testDynamicValueUsingBounds);
324}
325
326// Helper to iterate over n-D vector slice elements. Calculate the next
327// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
328// Modifies the `position` in place. Returns a failure when `position` becomes
329// the end position.
330static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
331 ArrayRef<int64_t> shape,
332 ArrayRef<int64_t> offsets) {
333 for (auto [posInDim, dimSize, offsetInDim] :
334 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
335 ++posInDim;
336 if (posInDim < dimSize + offsetInDim)
337 return success();
338
339 // Carry the overflow to the next loop iteration.
340 posInDim = offsetInDim;
341 }
342
343 return failure();
344}
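// Walking a 2x3 slice with offsets [0, 0], for illustration, the successive
// positions are [0,0] -> [0,1] -> [0,2] -> [1,0] -> [1,1] -> [1,2], after
// which incSlicePosition returns failure to signal the end of iteration.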
345
346/// Returns the integer numbers in `values`. `values` are expected to be
347/// constant operations.
348SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
349 SmallVector<int64_t> ints;
350 llvm::transform(values, std::back_inserter(ints), [](Value value) {
351 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
352 assert(constOp && "Unexpected non-constant index");
353 return constOp.value();
354 });
355 return ints;
356}
357
358/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
359/// be constant operations.
360SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
361 SmallVector<int64_t> ints;
362 llvm::transform(
363 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
364 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
365 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
366 });
367 return ints;
368}
369
370/// Convert `foldResults` into Values. Integer attributes are converted to
371/// constant ops.
372SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
373 ArrayRef<OpFoldResult> foldResults) {
374 SmallVector<Value> values;
375 llvm::transform(foldResults, std::back_inserter(values),
376 [&](OpFoldResult foldResult) {
377 if (auto attr = dyn_cast<Attribute>(foldResult))
378 return arith::ConstantIndexOp::create(
379 builder, loc, cast<IntegerAttr>(attr).getInt())
380 .getResult();
381
382 return cast<Value>(foldResult);
383 });
384 return values;
385}
386
387std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
388 if (value.getDefiningOp<vector::VectorScaleOp>())
389 return 1;
390 auto mul = value.getDefiningOp<arith::MulIOp>();
391 if (!mul)
392 return {};
393 auto lhs = mul.getLhs();
394 auto rhs = mul.getRhs();
395 if (lhs.getDefiningOp<vector::VectorScaleOp>())
396 return getConstantIntValue(rhs);
397 if (rhs.getDefiningOp<vector::VectorScaleOp>())
398 return getConstantIntValue(lhs);
399 return {};
400}
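// Illustrative usage (hypothetical SSA names): given
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index
// getConstantVscaleMultiplier(%n) returns 4, and for %vs alone it returns 1.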
401
402/// Converts numeric attributes to the expected type. Supports
403/// integer-to-integer and float-to-integer conversions. Returns the original
404/// attribute if no conversion is needed or supported.
405static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
406 // Integer-to-integer conversion
407 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
408 if (auto intType = dyn_cast<IntegerType>(expectedType)) {
409 if (intAttr.getType() != expectedType)
410 return IntegerAttr::get(expectedType, intAttr.getInt());
411 }
412 return attr;
413 }
414
415 // Float-to-integer bitcast (preserves bit representation)
416 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
417 auto intType = dyn_cast<IntegerType>(expectedType);
418 if (!intType)
419 return attr;
420
421 APFloat floatVal = floatAttr.getValue();
422 APInt intVal = floatVal.bitcastToAPInt();
423 return IntegerAttr::get(expectedType, intVal);
424 }
425
426 return attr;
427}
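// For illustration: converting IntegerAttr 42 : i32 to an expected i64 type
// yields 42 : i64, while converting FloatAttr 1.0 : f32 to an expected i32
// type bitcasts the IEEE-754 representation and yields 0x3F800000 : i32.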
428
429//===----------------------------------------------------------------------===//
430// CombiningKindAttr
431//===----------------------------------------------------------------------===//
432
433namespace mlir {
434namespace vector {
435namespace detail {
436struct BitmaskEnumStorage : public AttributeStorage {
437 using KeyTy = uint64_t;
438
439 BitmaskEnumStorage(KeyTy val) : value(val) {}
440
441 bool operator==(const KeyTy &key) const { return value == key; }
442
443 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
444 const KeyTy &key) {
445 return new (allocator.allocate<BitmaskEnumStorage>())
446 BitmaskEnumStorage(key);
447 }
448
449 KeyTy value = 0;
450};
451} // namespace detail
452} // namespace vector
453} // namespace mlir
454
455//===----------------------------------------------------------------------===//
456// VectorDialect
457//===----------------------------------------------------------------------===//
458
459namespace {
460/// This class defines the interface for handling inlining with vector dialect
461/// operations.
462struct VectorInlinerInterface : public DialectInlinerInterface {
463 using DialectInlinerInterface::DialectInlinerInterface;
464
465 /// All vector dialect ops can be inlined.
466 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
467 return true;
468 }
469};
470} // namespace
471
472void VectorDialect::initialize() {
473 addAttributes<
474#define GET_ATTRDEF_LIST
475#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
476 >();
477
478 addOperations<
479#define GET_OP_LIST
480#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
481 >();
482
483 addInterfaces<VectorInlinerInterface>();
484
485 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
486 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
487 YieldOp>();
488 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
489 TransferWriteOp>();
490 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
491 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
492 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
493}
494
495/// Materialize a single constant operation from a given attribute value with
496/// the desired resultant type.
497Operation *VectorDialect::materializeConstant(OpBuilder &builder,
498 Attribute value, Type type,
499 Location loc) {
500 if (isa<ub::PoisonAttrInterface>(value))
501 return value.getDialect().materializeConstant(builder, value, type, loc);
502
503 return arith::ConstantOp::materialize(builder, value, type, loc);
504}
505
506Type vector::getVectorSubscriptType(Builder &builder) {
507 return builder.getIntegerType(64);
508}
509
510Attribute vector::getVectorSubscriptAttr(Builder &builder,
511 ArrayRef<int64_t> values) {
512 return builder.getI64ArrayAttr(values);
513}
514
515//===----------------------------------------------------------------------===//
516// MultiDimReductionOp
517//===----------------------------------------------------------------------===//
518
519void vector::MultiDimReductionOp::build(OpBuilder &builder,
520 OperationState &result, Value source,
521 Value acc, ArrayRef<bool> reductionMask,
522 CombiningKind kind) {
523 SmallVector<int64_t> reductionDims;
524 for (const auto &en : llvm::enumerate(reductionMask))
525 if (en.value())
526 reductionDims.push_back(en.index());
527 build(builder, result, kind, source, acc, reductionDims);
528}
529
530OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
531 // Single parallel dim, this is a noop.
532 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
533 return getSource();
534 return {};
535}
536
537std::optional<SmallVector<int64_t, 4>>
538MultiDimReductionOp::getShapeForUnroll() {
539 return llvm::to_vector<4>(getSourceVectorType().getShape());
540}
541
542LogicalResult MultiDimReductionOp::verify() {
543 SmallVector<int64_t> targetShape;
544 SmallVector<bool> scalableDims;
545 Type inferredReturnType;
546 auto sourceScalableDims = getSourceVectorType().getScalableDims();
547 for (auto [dimIdx, dimSize] :
548 llvm::enumerate(getSourceVectorType().getShape()))
549 if (!llvm::any_of(getReductionDims(),
550 [dimIdx = dimIdx](int64_t reductionDimIdx) {
551 return reductionDimIdx == static_cast<int64_t>(dimIdx);
552 })) {
553 targetShape.push_back(dimSize);
554 scalableDims.push_back(sourceScalableDims[dimIdx]);
555 }
556 // TODO: update to also allow 0-d vectors when available.
557 if (targetShape.empty())
558 inferredReturnType = getSourceVectorType().getElementType();
559 else
560 inferredReturnType = VectorType::get(
561 targetShape, getSourceVectorType().getElementType(), scalableDims);
562 if (getType() != inferredReturnType)
563 return emitOpError() << "destination type " << getType()
564 << " is incompatible with source type "
565 << getSourceVectorType();
566
567 return success();
568}
569
570/// Returns the mask type expected by this operation.
571Type MultiDimReductionOp::getExpectedMaskType() {
572 auto vecType = getSourceVectorType();
573 return VectorType::get(vecType.getShape(),
574 IntegerType::get(vecType.getContext(), /*width=*/1),
575 vecType.getScalableDims());
576}
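// For illustration: reducing a vector<2x[4]x8xf32> source (whatever the
// reduction dims) expects a mask of type vector<2x[4]x8xi1>, i.e. the mask
// mirrors the source shape and scalable dims with an i1 element type.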
577
578namespace {
579// Only unit dimensions that are being reduced are folded. If the dimension is
580// unit, but not reduced, it is not folded, thereby keeping the output type the
581// same. If not all dimensions which are reduced are of unit dimension, this
582// transformation does nothing. This is just a generalization of
583// ElideSingleElementReduction for ReduceOp.
584struct ElideUnitDimsInMultiDimReduction
585 : public OpRewritePattern<MultiDimReductionOp> {
586 using Base::Base;
587
588 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
589 PatternRewriter &rewriter) const override {
590 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
591 for (const auto &dim : enumerate(shape)) {
592 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
593 return failure();
594 }
595
596 // Vector mask setup.
597 OpBuilder::InsertionGuard guard(rewriter);
598 Operation *rootOp;
599 Value mask;
600 if (reductionOp.isMasked()) {
601 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
602 rootOp = reductionOp.getMaskingOp();
603 mask = reductionOp.getMaskingOp().getMask();
604 } else {
605 rootOp = reductionOp;
606 }
607
608 Location loc = reductionOp.getLoc();
609 Value acc = reductionOp.getAcc();
610 Value cast;
611 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
612 if (mask) {
613 VectorType newMaskType =
614 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
615 dstVecType.getScalableDims());
616 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
617 }
618 cast = vector::ShapeCastOp::create(
619 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
620 } else {
621 // This means we are reducing all the dimensions, and all reduction
622 // dimensions are of size 1. So a simple extraction would do.
623 if (mask)
624 mask = vector::ExtractOp::create(rewriter, loc, mask);
625 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
626 }
627
628 Value result =
629 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
630 cast, /*fastmath=*/nullptr, mask);
631 rewriter.replaceOp(rootOp, result);
632 return success();
633 }
634};
635} // namespace
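// Illustrative effect of the pattern above (hypothetical shapes): a reduction
// over a unit dimension
//
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
//
// is rewritten into a shape_cast of %src combined with %acc, since reducing a
// size-1 dimension performs no accumulation across lanes.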
636
637void MultiDimReductionOp::getCanonicalizationPatterns(
638 RewritePatternSet &results, MLIRContext *context) {
639 results.add<ElideUnitDimsInMultiDimReduction>(context);
640}
641
642//===----------------------------------------------------------------------===//
643// ReductionOp
644//===----------------------------------------------------------------------===//
645
646void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
647 CombiningKind kind, Value vector,
648 arith::FastMathFlags fastMathFlags) {
649 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
650}
651
652void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
653 CombiningKind kind, Value vector, Value acc,
654 arith::FastMathFlags fastMathFlags) {
655 build(builder, result,
656 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
657 acc, fastMathFlags);
658}
659
660LogicalResult ReductionOp::verify() {
661 // Verify for 0-D and 1-D vector.
662 int64_t rank = getSourceVectorType().getRank();
663 if (rank > 1)
664 return emitOpError("unsupported reduction rank: ") << rank;
665
666 // Verify supported reduction kind.
667 Type eltType = getDest().getType();
668 if (!isSupportedCombiningKind(getKind(), eltType))
669 return emitOpError("unsupported reduction type '")
670 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
671 << "'";
672
673 return success();
674}
675
676// MaskableOpInterface methods.
677
678/// Returns the mask type expected by this operation.
679Type ReductionOp::getExpectedMaskType() {
680 auto vecType = getSourceVectorType();
681 return VectorType::get(vecType.getShape(),
682 IntegerType::get(vecType.getContext(), /*width=*/1),
683 vecType.getScalableDims());
684}
685
686Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
687 OpBuilder &builder, Location loc,
688 Value vector) {
689 switch (op) {
690 case arith::AtomicRMWKind::addf:
691 case arith::AtomicRMWKind::addi:
692 return vector::ReductionOp::create(builder, vector.getLoc(),
693 CombiningKind::ADD, vector);
694 case arith::AtomicRMWKind::mulf:
695 case arith::AtomicRMWKind::muli:
696 return vector::ReductionOp::create(builder, vector.getLoc(),
697 CombiningKind::MUL, vector);
698 case arith::AtomicRMWKind::minimumf:
699 return vector::ReductionOp::create(builder, vector.getLoc(),
700 CombiningKind::MINIMUMF, vector);
701 case arith::AtomicRMWKind::mins:
702 return vector::ReductionOp::create(builder, vector.getLoc(),
703 CombiningKind::MINSI, vector);
704 case arith::AtomicRMWKind::minu:
705 return vector::ReductionOp::create(builder, vector.getLoc(),
706 CombiningKind::MINUI, vector);
707 case arith::AtomicRMWKind::maximumf:
708 return vector::ReductionOp::create(builder, vector.getLoc(),
709 CombiningKind::MAXIMUMF, vector);
710 case arith::AtomicRMWKind::maxs:
711 return vector::ReductionOp::create(builder, vector.getLoc(),
712 CombiningKind::MAXSI, vector);
713 case arith::AtomicRMWKind::maxu:
714 return vector::ReductionOp::create(builder, vector.getLoc(),
715 CombiningKind::MAXUI, vector);
716 case arith::AtomicRMWKind::andi:
717 return vector::ReductionOp::create(builder, vector.getLoc(),
718 CombiningKind::AND, vector);
719 case arith::AtomicRMWKind::ori:
720 return vector::ReductionOp::create(builder, vector.getLoc(),
721 CombiningKind::OR, vector);
722 case arith::AtomicRMWKind::minnumf:
723 return vector::ReductionOp::create(builder, vector.getLoc(),
724 CombiningKind::MINNUMF, vector);
725 case arith::AtomicRMWKind::maxnumf:
726 return vector::ReductionOp::create(builder, vector.getLoc(),
727 CombiningKind::MAXNUMF, vector);
728 case arith::AtomicRMWKind::xori:
729 return vector::ReductionOp::create(builder, vector.getLoc(),
730 CombiningKind::XOR, vector);
731 default:
732 (void)emitOptionalError(loc, "Reduction operation type not supported");
733 break;
734 }
735 return nullptr;
736}
737
738std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
739 return llvm::to_vector<4>(getSourceVectorType().getShape());
740}
741
742namespace {
743struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
744 using Base::Base;
745
746 LogicalResult matchAndRewrite(ReductionOp reductionOp,
747 PatternRewriter &rewriter) const override {
748 // Vector mask setup.
749 OpBuilder::InsertionGuard guard(rewriter);
750 auto maskableOp =
751 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
752 Operation *rootOp;
753 Value mask;
754 if (maskableOp.isMasked()) {
755 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
756 rootOp = maskableOp.getMaskingOp();
757 mask = maskableOp.getMaskingOp().getMask();
758 } else {
759 rootOp = reductionOp;
760 }
761
762 auto vectorType = reductionOp.getSourceVectorType();
763 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
764 return failure();
765
766 Location loc = reductionOp.getLoc();
767 if (mask)
768 mask = ExtractOp::create(rewriter, loc, mask);
769 Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
770
771 if (Value acc = reductionOp.getAcc())
772 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
773 result, acc,
774 reductionOp.getFastmathAttr(), mask);
775
776 rewriter.replaceOp(rootOp, result);
777 return success();
778 }
779};
780} // namespace
781
782void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
783 MLIRContext *context) {
784 results.add<ElideSingleElementReduction>(context);
785}
786
787//===----------------------------------------------------------------------===//
788// ContractionOp
789//===----------------------------------------------------------------------===//
790
791void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
792 Value lhs, Value rhs, Value acc,
793 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
794 ArrayRef<IteratorType> iteratorTypes) {
795 result.addOperands({lhs, rhs, acc});
796 result.addTypes(acc.getType());
797 result.addAttribute(
798 getIndexingMapsAttrName(result.name),
799 builder.getAffineMapArrayAttr(
800 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
801 result.addAttribute(
802 getIteratorTypesAttrName(result.name),
803 builder.getArrayAttr(llvm::map_to_vector(
804 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
805 return IteratorTypeAttr::get(builder.getContext(), t);
806 })));
807}
808
809void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
810 Value lhs, Value rhs, Value acc,
811 ArrayAttr indexingMaps,
812 ArrayAttr iteratorTypes) {
813 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
814 ContractionOp::getDefaultKind());
815}
816
817void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
818 Value lhs, Value rhs, Value acc,
819 ArrayAttr indexingMaps,
820 ArrayAttr iteratorTypes, CombiningKind kind) {
821 result.addOperands({lhs, rhs, acc});
822 result.addTypes(acc.getType());
823 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
824 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
825 result.addAttribute(getKindAttrName(result.name),
826 CombiningKindAttr::get(builder.getContext(), kind));
827}
828
829ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
830 OpAsmParser::UnresolvedOperand lhsInfo;
831 OpAsmParser::UnresolvedOperand rhsInfo;
832 OpAsmParser::UnresolvedOperand accInfo;
833 SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
834 SmallVector<Type, 2> types;
835 Type resultType;
836 auto loc = parser.getCurrentLocation();
837 DictionaryAttr dictAttr;
838 // TODO: Unify linalg op attribute parsing.
839 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
840 parser.parseComma() || parser.parseOperand(rhsInfo) ||
841 parser.parseComma() || parser.parseOperand(accInfo) ||
842 parser.parseTrailingOperandList(masksInfo) ||
843 parser.parseOptionalAttrDict(result.attributes) ||
844 parser.parseColonTypeList(types) ||
845 parser.parseKeywordType("into", resultType) ||
846 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
847 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
848 parser.resolveOperand(accInfo, resultType, result.operands) ||
849 parser.addTypeToList(resultType, result.types))
850 return failure();
851 result.attributes.append(dictAttr.getValue().begin(),
852 dictAttr.getValue().end());
853
854 // Convert the array of strings into an array of IteratorType enums. This is
855 // needed because tests still use the old format, where the 'iterator_types'
856 // attribute is represented as an array of strings.
857 // TODO: Remove this conversion once tests are fixed.
858 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
859 result.attributes.get(getIteratorTypesAttrName(result.name)));
860 if (!iteratorTypes) {
861 return parser.emitError(loc)
862 << "expected " << getIteratorTypesAttrName(result.name)
863 << " array attribute";
864 }
865
866 SmallVector<Attribute> iteratorTypeAttrs;
867
868 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
869 auto maybeIteratorType = symbolizeIteratorType(s);
870 if (!maybeIteratorType.has_value())
871 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
872
873 iteratorTypeAttrs.push_back(
874 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
875 }
876 result.attributes.set(getIteratorTypesAttrName(result.name),
877 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
878
879 if (!result.attributes.get(getKindAttrName(result.name))) {
880 result.addAttribute(
881 getKindAttrName(result.name),
882 CombiningKindAttr::get(result.getContext(),
883 ContractionOp::getDefaultKind()));
884 }
885 if (masksInfo.empty())
886 return success();
887 if (masksInfo.size() != 2)
888 return parser.emitError(parser.getNameLoc(),
889 "expected zero or exactly 2 vector mask operands");
890 auto lhsType = llvm::cast<VectorType>(types[0]);
891 auto rhsType = llvm::cast<VectorType>(types[1]);
892 auto maskElementType = parser.getBuilder().getI1Type();
893 std::array<VectorType, 2> maskTypes = {
894 VectorType::Builder(lhsType).setElementType(maskElementType),
895 VectorType::Builder(rhsType).setElementType(maskElementType)};
896 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
897 return failure();
898 return success();
899}
900
901void ContractionOp::print(OpAsmPrinter &p) {
902 // TODO: Unify printing code with linalg ops.
903 auto attrNames = getTraitAttrNames();
904 llvm::StringSet<> traitAttrsSet;
905 traitAttrsSet.insert_range(attrNames);
906 SmallVector<NamedAttribute> attrs;
907 for (auto attr : (*this)->getAttrs()) {
908 if (attr.getName() == getIteratorTypesAttrName()) {
909 auto iteratorTypes =
910 llvm::cast<ArrayAttr>(attr.getValue())
911 .getAsValueRange<IteratorTypeAttr, IteratorType>();
912 // Convert IteratorType enums into the string representation. This is
913 // needed because tests still use the old format, where the 'iterator_types'
914 // attribute is represented as an array of strings.
915 // TODO: Remove this conversion once tests are fixed.
916 SmallVector<Attribute> iteratorTypeNames =
917 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
918 return StringAttr::get(getContext(), stringifyIteratorType(t));
919 });
920
921 attrs.emplace_back(getIteratorTypesAttrName(),
922 ArrayAttr::get(getContext(), iteratorTypeNames));
923 } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
924 attrs.push_back(attr);
925 }
926
927 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
928 p << " " << dictAttr << " " << getLhs() << ", ";
929 p << getRhs() << ", " << getAcc();
930
931 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
932 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
933 << getResultType();
934}
935
936static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
937 const std::vector<std::pair<int64_t, int64_t>> &map) {
938 for (auto &dimPair : map) {
939 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
940 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
941 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
942 return false;
943 }
944 return true;
945}
946
947static LogicalResult verifyOutputShape(
948 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
949 Type resType,
950 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
951 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
952 DenseSet<int64_t> lhsContractingDimSet;
953 DenseSet<int64_t> rhsContractingDimSet;
954 for (auto &dimPair : contractingDimMap) {
955 lhsContractingDimSet.insert(dimPair.first);
956 rhsContractingDimSet.insert(dimPair.second);
957 }
958 DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
959 llvm::make_second_range(batchDimMap));
960
961 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
962 SmallVector<int64_t, 4> expectedResultDims;
963 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
964 if (lhsContractingDimSet.count(i) > 0)
965 continue;
966 expectedResultDims.push_back(lhsType.getDimSize(i));
967 }
968
969 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
970 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
971 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
972 continue;
973 expectedResultDims.push_back(rhsType.getDimSize(i));
974 }
975
976 // Verify 'expectedResultDims'.
977 if (expectedResultDims.empty()) {
978 // No batch or free dimension implies a scalar result.
979 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
980 return op.emitOpError("invalid accumulator/result vector shape");
981 } else {
982 // At least one batch or free dimension implies a vector result.
983 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
984 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
985 if (!resVectorType || !accVectorType)
986 return op.emitOpError("invalid accumulator/result vector shape");
987
988 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
989 // types fully define the result vector type. This assumes the affine maps
990 // are well-formed, which must have been verified already.
991 MLIRContext *ctx = op.getContext();
992 AffineMap lhsMap = op.getIndexingMapsArray()[0];
993 AffineMap rhsMap = op.getIndexingMapsArray()[1];
994 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
995 return op.emitOpError(
996 "expected all dimensions to be either a LHS or a RHS dimension");
997 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
998 for (auto pair :
999 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1000 VectorType v = pair.first;
1001 auto map = pair.second;
1002 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1003 unsigned pos = map.getDimPosition(idx);
1004 if (!extents[pos])
1005 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
1006 }
1007 }
1008 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
1009 return op.emitOpError("expected all dimensions to get an extent as "
1010 "either a LHS or a RHS dimension");
1011
1012 AffineMap resMap = op.getIndexingMapsArray()[2];
1013 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
1014 /*symbolCount=*/0, extents, ctx);
1015 // Compose the resMap with the extentsMap, which is a constant map.
1016 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
1017 assert(llvm::all_of(expectedMap.getResults(),
1018 llvm::IsaPred<AffineConstantExpr>) &&
1019 "expected constant extent along all dimensions.");
1020 // Extract the expected shape and build the type.
1021 auto expectedShape =
1022 llvm::map_to_vector<4>(expectedMap.getResults(), [](AffineExpr e) {
1023 return cast<AffineConstantExpr>(e).getValue();
1024 });
1025 auto expected =
1026 VectorType::get(expectedShape, resVectorType.getElementType(),
1027 resVectorType.getScalableDims());
1028 if (resVectorType != expected || accVectorType != expected)
1029 return op.emitOpError(
1030 "invalid accumulator/result vector shape, expected: ")
1031 << expected;
1032 }
1033 return success();
1034}
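// For illustration (hypothetical matmul-style contraction): with
// lhs = vector<2x3xf32>, rhs = vector<3x4xf32>, a single contracting dim
// pairing lhs dim 1 with rhs dim 0 and no batch dims, the expected
// accumulator/result type is vector<2x4xf32>; any other acc/result shape is
// rejected by the verification above.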
1035
1036LogicalResult ContractionOp::verify() {
1037 VectorType lhsType = getLhsType();
1038 VectorType rhsType = getRhsType();
1039 Type accType = getAccType();
1040 Type resType = getResultType();
1041
1042 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1043 if (!lhsType.getElementType().isSignlessInteger())
1044 return emitOpError("only supports signless integer types");
1045 }
1046
1047 // Verify that an indexing map was specified for each vector operand.
1048 if (getIndexingMapsArray().size() != 3)
1049 return emitOpError("expected an indexing map for each vector operand");
1050
1051 // Verify that each index map has 'numIterators' inputs, no symbols, and
1052 // that the number of map outputs equals the rank of its associated
1053 // vector operand.
1054 unsigned numIterators = getIteratorTypes().getValue().size();
1055 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1056 auto index = it.index();
1057 auto map = it.value();
1058 if (map.getNumSymbols() != 0)
1059 return emitOpError("expected indexing map ")
1060 << index << " to have no symbols";
1061 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1062 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1063 // Verify that the map has the right number of inputs, outputs, and indices.
1064 // This also correctly accounts for (..) -> () for rank-0 results.
1065 if (map.getNumDims() != numIterators)
1066 return emitOpError("expected indexing map ")
1067 << index << " to have " << numIterators << " number of inputs";
1068 if (map.getNumResults() != rank)
1069 return emitOpError("expected indexing map ")
1070 << index << " to have " << rank << " number of outputs";
1071 if (!map.isProjectedPermutation())
1072 return emitOpError("expected indexing map ")
1073 << index << " to be a projected permutation of its inputs";
1074 }
1075
1076 auto contractingDimMap = getContractingDimMap();
1077 auto batchDimMap = getBatchDimMap();
1078
1079 // Verify at least one contracting dimension pair was specified.
1080 if (contractingDimMap.empty())
1081 return emitOpError("expected at least one contracting dimension pair");
1082
1083 // Verify contracting dimension map was properly constructed.
1084 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1085 return emitOpError("invalid contracting dimension map");
1086
1087 // Verify batch dimension map was properly constructed.
1088 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1089 return emitOpError("invalid batch dimension map");
1090
1091 // Verify 'accType' and 'resType' shape.
1092 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1093 contractingDimMap, batchDimMap)))
1094 return failure();
1095
1096 if (!getKindAttr()) {
1097 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
1098 "'vector.kind<add>')");
1099 }
1100
1101 // Verify supported combining kind.
1102 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1103 auto elementType = vectorType ? vectorType.getElementType() : resType;
1104 if (!isSupportedCombiningKind(getKind(), elementType))
1105 return emitOpError("unsupported contraction type");
1106
1107 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1108 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1109}
1110
1111// MaskableOpInterface methods.
1112
1113/// Returns the mask type expected by this operation. Mostly used for
1114/// verification purposes. It requires the operation to be vectorized.
1115Type ContractionOp::getExpectedMaskType() {
1116 auto indexingMaps = this->getIndexingMapsArray();
1117 AffineMap lhsIdxMap = indexingMaps[0];
1118 AffineMap rhsIdxMap = indexingMaps[1];
1119 VectorType lhsType = this->getLhsType();
1120 VectorType rhsType = this->getRhsType();
1121
1122 unsigned numVecDims = lhsIdxMap.getNumDims();
1123 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1124 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1125
1126 // Using the information in the indexing maps, extract the size of each
1127 // dimension in the vector.contract operation from the two input operands.
1128 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1129 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1130 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1131 lhsType.getScalableDims()[dimIdx];
1132 }
1133 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1134 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1135 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1136 rhsType.getScalableDims()[dimIdx];
1137 }
1138
1139 assert(ShapedType::isStaticShape(maskShape) &&
1140 "Mask shape couldn't be computed");
1141
1142 return VectorType::get(maskShape,
1143 IntegerType::get(lhsType.getContext(), /*width=*/1),
1144 maskShapeScalableDims);
1145}
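// For illustration (hypothetical matmul maps over (m, n, k)): with
// lhs = vector<2x3xf32> indexed by (m, k) and rhs = vector<3x4xf32> indexed
// by (k, n), the iteration space is (m, n, k) = (2, 4, 3) and the expected
// mask type is vector<2x4x3xi1>.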
1146
1147SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1148 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1149 getIteratorTypesAttrName(), getKindAttrName()};
1150}
1151
1152static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1153 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1154 if (targetExpr == map.getResult(i))
1155 return i;
1156 return -1;
1157}
1158
1159static std::vector<std::pair<int64_t, int64_t>>
1160getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1161 IteratorType targetIteratorType, MLIRContext *context) {
1162 std::vector<std::pair<int64_t, int64_t>> dimMap;
1163 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1164 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1165 if (iteratorType != targetIteratorType)
1166 continue;
1167 // Search lhs/rhs map results for 'targetExpr'.
1168 auto targetExpr = getAffineDimExpr(it.index(), context);
1169 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1170 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1171 if (lhsDim >= 0 && rhsDim >= 0)
1172 dimMap.emplace_back(lhsDim, rhsDim);
1173 }
1174 return dimMap;
1175}
1176
1177void ContractionOp::getIterationBounds(
1178 SmallVectorImpl<int64_t> &iterationBounds) {
1179 auto lhsShape = getLhsType().getShape();
1180 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1181 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1182 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1183 // Search lhs/rhs map results for 'targetExpr'.
1184 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1185 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1186 if (iteratorType == IteratorType::reduction) {
1187 // Get reduction dim size from lhs shape (same size in rhsShape).
1188 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1189 assert(lhsDimIndex >= 0);
1190 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1191 continue;
1192 }
1193 // Get parallel dimension size from result shape.
1194 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1195 assert(resDimIndex >= 0);
1196 assert(resVectorType != nullptr);
1197 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1198 }
1199}
1200
1201void ContractionOp::getIterationIndexMap(
1202 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1203 unsigned numMaps = getIndexingMapsArray().size();
1204 iterationIndexMap.resize(numMaps);
1205 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1206 auto index = it.index();
1207 auto map = it.value();
1208 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1209 auto dim = cast<AffineDimExpr>(map.getResult(i));
1210 iterationIndexMap[index][dim.getPosition()] = i;
1211 }
1212 }
1213}
1214
1215std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1216 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1217 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1218 getContext());
1219}
1220
1221std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1222 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1223 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1224 getContext());
1225}
1226
1227std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1228 SmallVector<int64_t, 4> shape;
1229 getIterationBounds(shape);
1230 return shape;
1231}
1232
1233/// Return a fused vector::ContractionOp which represents a pattern such as:
1234///
1235/// ```mlir
1236/// %c0 = vector.constant 0: ...
1237/// %c = vector.contract %a, %b, %c0: ...
1238/// %e = add %c, %d: ...
1239/// ```
1240///
1241/// by:
1242///
1243/// ```mlir
1244/// %e = vector.contract %a, %b, %d: ...
1245/// ```
1246///
1247/// Return null if the canonicalization does not apply.
1248// TODO: This should be a folding of Add into Contract in core but while they
1249// live in different dialects, it is not possible without unnatural
1250// dependencies.
1251template <typename AddOpType>
1252struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1253 using OpRewritePattern<AddOpType>::OpRewritePattern;
1254
1255 LogicalResult matchAndRewrite(AddOpType addOp,
1256 PatternRewriter &rewriter) const override {
1257 auto canonicalize = [&](Value maybeContraction,
1258 Value otherOperand) -> vector::ContractionOp {
1259 vector::ContractionOp contractionOp =
1260 dyn_cast_or_null<vector::ContractionOp>(
1261 maybeContraction.getDefiningOp());
1262 if (!contractionOp)
1263 return vector::ContractionOp();
1264 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1265 contractionOp.getAcc().getDefiningOp())) {
1266 if (maybeZero.getValue() ==
1267 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1268 IRMapping bvm;
1269 bvm.map(contractionOp.getAcc(), otherOperand);
1270 auto newContraction =
1271 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1272 rewriter.replaceOp(addOp, newContraction.getResult());
1273 return newContraction;
1274 }
1275 }
1276 return vector::ContractionOp();
1277 };
1278
1279 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1280 vector::ContractionOp contract = canonicalize(a, b);
1281 contract = contract ? contract : canonicalize(b, a);
1282 return contract ? success() : failure();
1283 }
1284};
1285
1286void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1287 MLIRContext *context) {
1288 results.add<CanonicalizeContractAdd<arith::AddIOp>,
1289 CanonicalizeContractAdd<arith::AddFOp>>(context);
1290}
1291
1292// Returns `true` if `index` is either within [0, maxIndex) or equal to
1293// `poisonValue`.
1294static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1295 int64_t maxIndex) {
1296 return index == poisonValue || (index >= 0 && index < maxIndex);
1297}
1298
1299//===----------------------------------------------------------------------===//
1300// ExtractOp
1301//===----------------------------------------------------------------------===//
1302
1303void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1304 SetIntRangeFn setResultRanges) {
1305 setResultRanges(getResult(), argRanges.front());
1306}
1307
1308void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1309 Value source) {
1310 auto vectorTy = cast<VectorType>(source.getType());
1311 build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1312}
1313
1314void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1315 Value source, int64_t position) {
1316 build(builder, result, source, ArrayRef<int64_t>{position});
1317}
1318
1319void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1320 Value source, OpFoldResult position) {
1321 build(builder, result, source, ArrayRef<OpFoldResult>{position});
1322}
1323
1324void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1325 Value source, ArrayRef<int64_t> position) {
1326 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1327 builder.getDenseI64ArrayAttr(position));
1328}
1329
1330void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1331 Value source, ArrayRef<OpFoldResult> position) {
1332 SmallVector<int64_t> staticPos;
1333 SmallVector<Value> dynamicPos;
1334 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1335 build(builder, result, source, dynamicPos,
1336 builder.getDenseI64ArrayAttr(staticPos));
1337}
1338
1339LogicalResult
1340ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1341 ExtractOp::Adaptor adaptor,
1342 SmallVectorImpl<Type> &inferredReturnTypes) {
1343 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1344 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1345 vectorType.getRank()) {
1346 inferredReturnTypes.push_back(vectorType.getElementType());
1347 } else {
1348 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1349 vectorType.getRank());
1350 inferredReturnTypes.push_back(VectorType::get(
1351 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1352 vectorType.getScalableDims().drop_front(n)));
1353 }
1354 return success();
1355}
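// For illustration: extracting at position [2] from vector<4x8x16xf32> infers
// the result type vector<8x16xf32>, while a full-rank position such as
// [2, 3, 1] infers the scalar element type f32.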
1356
1357bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1358 // Allow extracting 1-element vectors instead of scalars.
1359 auto isCompatible = [](TypeRange l, TypeRange r) {
1360 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1361 return vectorType && vectorType.getShape().equals({1}) &&
1362 vectorType.getElementType() == r.front();
1363 };
1364 if (l.size() == 1 && r.size() == 1 &&
1365 (isCompatible(l, r) || isCompatible(r, l)))
1366 return true;
1367 return l == r;
1368}
1369
1370LogicalResult vector::ExtractOp::verify() {
1371 if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1372 if (resTy.getRank() == 0)
1373 return emitError(
1374 "expected a scalar instead of a 0-d vector as the result type");
1375
1376 // Note: This check must come before getMixedPosition() to prevent a crash.
1377 auto dynamicMarkersCount =
1378 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1379 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1380 return emitOpError(
1381 "mismatch between dynamic and static positions (kDynamic marker but no "
1382 "corresponding dynamic position) -- this can only happen due to an "
1383 "incorrect fold/rewrite");
1384 auto position = getMixedPosition();
1385 if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1386 return emitOpError(
1387 "expected position attribute of rank no greater than vector rank");
1388 for (auto [idx, pos] : llvm::enumerate(position)) {
1389 if (auto attr = dyn_cast<Attribute>(pos)) {
1390 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1391 if (!isValidPositiveIndexOrPoison(
1392 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1393 return emitOpError("expected position attribute #")
1394 << (idx + 1)
1395 << " to be a non-negative integer smaller than the "
1396 "corresponding vector dimension or poison (-1)";
1397 }
1398 }
1399 }
1400 return success();
1401}
1402
1403template <typename IntType>
1404static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1405 return llvm::map_to_vector<4>(
1406 arrayAttr.getAsRange<IntegerAttr>(),
1407 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1408}
1409
1410/// Fold the result of chains of ExtractOp in place by simply concatenating the
1411/// positions.
1412static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1413 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1414 return failure();
1415
1416 // TODO: Canonicalization for dynamic position not implemented yet.
1417 if (extractOp.hasDynamicPosition())
1418 return failure();
1419
1420 SmallVector<int64_t> globalPosition;
1421 ExtractOp currentOp = extractOp;
1422 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1423 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1424 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1425 currentOp = nextOp;
1426 // TODO: Canonicalization for dynamic position not implemented yet.
1427 if (currentOp.hasDynamicPosition())
1428 return failure();
1429 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1430 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1431 }
1432 extractOp.setOperand(0, currentOp.getSource());
1433 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1434 OpBuilder b(extractOp.getContext());
1435 std::reverse(globalPosition.begin(), globalPosition.end());
1436 extractOp.setStaticPosition(globalPosition);
1437 return success();
1438}
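// Illustrative fold (hypothetical values): the chain
//
//   %a = vector.extract %v[1] : vector<8x16xf32> from vector<4x8x16xf32>
//   %b = vector.extract %a[2] : vector<16xf32> from vector<8x16xf32>
//
// is rewritten in place so that %b becomes vector.extract %v[1, 2].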
1439
1440namespace {
1441/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1442/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1443/// Compose TransposeOp permutations as we walk back.
1444/// This helper class keeps an updated extraction position `extractPosition`
1445/// with extra trailing sentinels.
1446/// The sentinels encode the internal transposition status of the result vector.
1447/// As we iterate, extractPosition is permuted and updated.
1448class ExtractFromInsertTransposeChainState {
1449public:
1450 ExtractFromInsertTransposeChainState(ExtractOp e);
1451
1452 /// Iterate over producing insert and transpose ops until we find a fold.
1453 Value fold();
1454
1455private:
1456 /// Return true if the vector at position `a` is contained within the vector
1457 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1458 /// is a prefix of `b`.
1459 template <typename ContainerA, typename ContainerB>
1460 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1461 return a.size() <= b.size() &&
1462 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1463 }
1464
1465 /// Return true if the vector at position `a` intersects the vector at
1466 /// position `b`. Under insert/extract semantics, this is the same as equality
1467 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1468 /// Comparison is on the common prefix (i.e. zip).
1469 template <typename ContainerA, typename ContainerB>
1470 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1471 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1472 if (elemA < 0 || elemB < 0)
1473 continue;
1474 if (elemA != elemB)
1475 return false;
1476 }
1477 return true;
1478 }
1479
1480 /// Folding is only possible in the absence of an internal permutation in the
1481 /// result vector.
1482 bool canFold() {
1483 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1484 }
1485
1486 // Helper to get the next defining op of interest.
1487 void updateStateForNextIteration(Value v) {
1488 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1489 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1490 };
1491
1492 // Case 1. If we hit a transpose, just compose the map and iterate.
1493 // Invariant: insert + transpose do not change rank, we can always compose.
1494 LogicalResult handleTransposeOp();
1495
1496 // Case 2: the insert position matches extractPosition exactly, early return.
1497 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1498
1499 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1500 /// portion of the source of the insert.
1501 /// Example:
1502 /// ```
1503 /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1504 /// // extractPosition == [1, 2, 3]
1505 /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1506 /// // can fold to vector.extract %source[0, 3]
1507 /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1508 /// ```
1509 /// To traverse through %source, we need to set the leading dims to 0 and
1510 /// drop the extra leading dims.
1511 /// This method updates the internal state.
1512 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1513
1514 /// Try to fold in place to extract(source, extractPosition) and return the
1515 /// folded result. Return null if folding is not possible (e.g. due to an
1516 /// internal transposition in the result).
1517 Value tryToFoldExtractOpInPlace(Value source);
1518
1519 ExtractOp extractOp;
1520 int64_t vectorRank;
1521 int64_t extractedRank;
1522
1523 InsertOp nextInsertOp;
1524 TransposeOp nextTransposeOp;
1525
1526 /// Sentinel values that encode the internal permutation status of the result.
1527 /// They are set to (-1, ... , -k) at the beginning and appended to
1528 /// `extractPosition`.
1529 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1530 /// ensure that there is no internal transposition.
1531 /// Internal transposition cannot be accounted for with a folding pattern.
1532 // TODO: We could relax the internal transposition with an extra transposition
1533 // operation in a future canonicalizer.
1534 SmallVector<int64_t> sentinels;
1535 SmallVector<int64_t> extractPosition;
1536};
1537} // namespace
1538
1539ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1540 ExtractOp e)
1541 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1542 extractedRank(extractOp.getNumIndices()) {
1543 assert(vectorRank >= extractedRank && "Extracted position overflow");
1544 sentinels.reserve(vectorRank - extractedRank);
1545 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1546 sentinels.push_back(-(i + 1));
1547 extractPosition.assign(extractOp.getStaticPosition().begin(),
1548 extractOp.getStaticPosition().end());
1549 llvm::append_range(extractPosition, sentinels);
1550}
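// Illustrative example of the initial state: for
//   %ext = vector.extract %v[2] : vector<4x5xf32> from vector<3x4x5xf32>
// the constructor sets vectorRank = 3, extractedRank = 1, sentinels = [-1, -2]
// and extractPosition = [2, -1, -2].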
1551
1552// Case 1. If we hit a transpose, just compose the map and iterate.
1553// Invariant: insert + transpose do not change rank, we can always compose.
1554LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1555 // TODO: Canonicalization for dynamic position not implemented yet.
1556 if (extractOp.hasDynamicPosition())
1557 return failure();
1558
1559 if (!nextTransposeOp)
1560 return failure();
1561 AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1562 nextTransposeOp.getPermutation(), extractOp.getContext()));
1563 extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1564 return success();
1565}
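// Illustrative walk-through, continuing the example above: if %v is defined by
//   %v = vector.transpose %w, [1, 0, 2] : vector<4x3x5xf32> to vector<3x4x5xf32>
// composing the inverse permutation rewrites extractPosition from [2, -1, -2]
// to [-1, 2, -2]. Its tail [2, -2] no longer equals the sentinels [-1, -2], so
// canFold() will report an internal transposition and the fold bails.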
1566
1567// Case 2: the insert position matches extractPosition exactly, early return.
1568LogicalResult
1569ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1570 Value &res) {
1571 // TODO: Canonicalization for dynamic position not implemented yet.
1572 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1573 return failure();
1574
1575 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1576 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1577 return failure();
1578 // Case 2.a. early-exit fold.
1579 res = nextInsertOp.getValueToStore();
1580 // Case 2.b. if internal transposition is present, canFold will be false.
1581 return success(canFold());
1582}
1583
1584/// Case 3: if inserted position is a prefix of extractPosition,
1585/// extract a portion of the source of the insertion.
1586/// This method updates the internal state.
1587LogicalResult
1588ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1589 // TODO: Canonicalization for dynamic position not implemented yet.
1590 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1591 return failure();
1592
1593 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1594 if (!isContainedWithin(insertedPos, extractPosition))
1595 return failure();
1596 // Set leading dims to zero.
1597 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1598 // Drop extra leading dims.
1599 extractPosition.erase(extractPosition.begin(),
1600 extractPosition.begin() + insertedPos.size());
1601 extractedRank = extractPosition.size() - sentinels.size();
1602 // Case 3.a. early-exit fold (break and delegate to post-while path).
1603 res = nextInsertOp.getValueToStore();
1604 // Case 3.b. if internal transposition is present, canFold will be false.
1605 return success();
1606}
1607
1608/// Try to fold in place to extract(source, extractPosition) and return the
1609/// folded result. Return null if folding is not possible (e.g. due to an
1610/// internal transposition in the result).
1611Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1612 Value source) {
1613 // TODO: Canonicalization for dynamic position not implemented yet.
1614 if (extractOp.hasDynamicPosition())
1615 return Value();
1616
1617 // If we can't fold (either internal transposition, or nothing to fold), bail.
1618 bool nothingToFold = (source == extractOp.getSource());
1619 if (nothingToFold || !canFold())
1620 return Value();
1621
1622 // Otherwise, fold by updating the op inplace and return its result.
1623 OpBuilder b(extractOp.getContext());
1624 extractOp.setStaticPosition(
1625 ArrayRef(extractPosition).take_front(extractedRank));
1626 extractOp.getSourceMutable().assign(source);
1627 return extractOp.getResult();
1628}
1629
1630/// Iterate over producing insert and transpose ops until we find a fold.
1631Value ExtractFromInsertTransposeChainState::fold() {
1632 // TODO: Canonicalization for dynamic position not implemented yet.
1633 if (extractOp.hasDynamicPosition())
1634 return Value();
1635
1636 Value valueToExtractFrom = extractOp.getSource();
1637 updateStateForNextIteration(valueToExtractFrom);
1638 while (nextInsertOp || nextTransposeOp) {
1639 // Case 1. If we hit a transpose, just compose the map and iterate.
1640 // Invariant: insert + transpose do not change rank, we can always compose.
1641 if (succeeded(handleTransposeOp())) {
1642 valueToExtractFrom = nextTransposeOp.getVector();
1643 updateStateForNextIteration(valueToExtractFrom);
1644 continue;
1645 }
1646
1647 Value result;
1648 // Case 2: the positions match exactly.
1649 if (succeeded(handleInsertOpWithMatchingPos(result)))
1650 return result;
1651
1652 // Case 3: if the inserted position is a prefix of extractPosition, we can
1653 // just extract a portion of the source of the insert.
1654 if (succeeded(handleInsertOpWithPrefixPos(result)))
1655 return tryToFoldExtractOpInPlace(result);
1656
1657 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1658 // values. This is a more difficult case and we bail.
1659 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1660 if (isContainedWithin(extractPosition, insertedPos) ||
1661 intersectsWhereNonNegative(extractPosition, insertedPos))
1662 return Value();
1663
1664 // Case 5: No intersection, we forward the extract to insertOp.dest().
1665 valueToExtractFrom = nextInsertOp.getDest();
1666 updateStateForNextIteration(valueToExtractFrom);
1667 }
1668 // If after all this we can fold, go for it.
1669 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1670}
1671
1672/// Returns true if the operation has a 0-D vector type operand or result.
1673 static bool hasZeroDimVectors(Operation *op) {
1674 auto hasZeroDimVectorType = [](Type type) -> bool {
1675 auto vecType = dyn_cast<VectorType>(type);
1676 return vecType && vecType.getRank() == 0;
1677 };
1678
1679 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1680 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1681}
1682
1683/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1684/// considered to be 'broadcastlike'.
1685static bool isBroadcastLike(Operation *op) {
1686 if (isa<BroadcastOp>(op))
1687 return true;
1688
1689 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1690 if (!shapeCast)
1691 return false;
1692
1693 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1694 // Checking that the destination shape has a prefix of 1s is not sufficient,
1695 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1696 // is that the source shape is a suffix of the destination shape.
1697 VectorType srcType = shapeCast.getSourceVectorType();
1698 ArrayRef<int64_t> srcShape = srcType.getShape();
1699 uint64_t srcRank = srcType.getRank();
1700 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1701 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1702}
1703
1704/// Fold extract(broadcast(X)) to either extract(X) or just X.
1705///
1706/// Example:
1707///
1708/// broadcast extract [1][2]
1709/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1710///
1711/// becomes
1712/// extract [2]
1713/// (3,4) -------------------------------------> (4)
1714///
1715///
1716/// The variable names used in this implementation correspond to the above
1717/// shapes as,
1718///
1719/// - (3, 4) is `input` shape.
1720/// - (2, 3, 4) is `broadcast` shape.
1721/// - (4) is `extract` shape.
1722///
1723/// This folding is possible when the suffix of `input` shape is the same as
1724/// `extract` shape.
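///
/// For instance (illustrative):
///   %b = vector.broadcast %src : vector<3x4xf32> to vector<2x3x4xf32>
///   %e = vector.extract %b[1, 2] : vector<4xf32> from vector<2x3x4xf32>
/// folds in place to
///   %e = vector.extract %src[2] : vector<4xf32> from vector<3x4xf32>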
1725static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1726
1727 Operation *defOp = extractOp.getSource().getDefiningOp();
1728 if (!defOp || !isBroadcastLike(defOp))
1729 return Value();
1730
1731 Value input = defOp->getOperand(0);
1732
1733 // Replace extract(broadcast(X)) with X
1734 if (extractOp.getType() == input.getType())
1735 return input;
1736
1737 // Get required types and ranks in the chain
1738 // input -> broadcast -> extract
1739 // (scalars are treated as rank-0).
1740 auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1741 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1742 unsigned inputRank = inputType ? inputType.getRank() : 0;
1743 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1744 unsigned extractRank = extractType ? extractType.getRank() : 0;
1745
1746 // Cannot do without the broadcast if overall the rank increases.
1747 if (extractRank > inputRank)
1748 return Value();
1749
1750 // The above condition guarantees that input is a vector.
1751 assert(inputType && "input must be a vector type because of previous checks");
1752 ArrayRef<int64_t> inputShape = inputType.getShape();
1753
1754 // In the case where there is a broadcast dimension in the suffix, it is not
1755 // possible to replace extract(broadcast(X)) with extract(X). Example:
1756 //
1757 // broadcast extract
1758 // (1) --------> (3,4) ------> (4)
1759 if (extractType &&
1760 extractType.getShape() != inputShape.take_back(extractRank))
1761 return Value();
1762
1763 // Replace extract(broadcast(X)) with extract(X).
1764 // First, determine the new extraction position.
1765 unsigned deltaOverall = inputRank - extractRank;
1766 unsigned deltaBroadcast = broadcastRank - inputRank;
1767 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1768 SmallVector<OpFoldResult> newPositions(deltaOverall);
1769 IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1770 for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1771 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1772 }
1773 auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1774 extractOp->setOperands(
1775 llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1776 extractOp.setStaticPosition(staticPos);
1777 return extractOp.getResult();
1778}
1779
1780/// Fold extractOp coming from ShuffleOp.
1781///
1782/// Example:
1783///
1784/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1785/// : vector<8xf32>, vector<8xf32>
1786/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1787/// ->
1788/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1789///
1790static Value foldExtractFromShuffle(ExtractOp extractOp) {
1791 // Dynamic positions are not folded as the resulting code would be more
1792 // complex than the input code.
1793 if (extractOp.hasDynamicPosition())
1794 return Value();
1795
1796 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1797 if (!shuffleOp)
1798 return Value();
1799
1800 // TODO: 0-D or multi-dimensional vectors not supported yet.
1801 if (shuffleOp.getResultVectorType().getRank() != 1)
1802 return Value();
1803
1804 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1805 auto shuffleMask = shuffleOp.getMask();
1806 int64_t extractIdx = extractOp.getStaticPosition()[0];
1807 int64_t shuffleIdx = shuffleMask[extractIdx];
1808
1809 // Find the shuffled vector to extract from based on the shuffle index.
1810 if (shuffleIdx < inputVecSize) {
1811 extractOp.setOperand(0, shuffleOp.getV1());
1812 extractOp.setStaticPosition({shuffleIdx});
1813 } else {
1814 extractOp.setOperand(0, shuffleOp.getV2());
1815 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1816 }
1817
1818 return extractOp.getResult();
1819}
1820
1821// Fold extractOp with source coming from ShapeCast op.
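// Example (illustrative):
//   %sc = vector.shape_cast %src : vector<4x4xf32> to vector<2x8xf32>
//   %e  = vector.extract %sc[1, 6] : f32 from vector<2x8xf32>
// The extracted element has linearized position 1 * 8 + 6 = 14, which
// delinearizes in the source shape to [3, 2], so the op is updated in place to
//   %e  = vector.extract %src[3, 2] : f32 from vector<4x4xf32>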
1822static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1823 // TODO: Canonicalization for dynamic position not implemented yet.
1824 if (extractOp.hasDynamicPosition())
1825 return Value();
1826
1827 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1828 if (!shapeCastOp)
1829 return Value();
1830
1831 // Get the nth dimension size starting from lowest dimension.
1832 auto getDimReverse = [](VectorType type, int64_t n) {
1833 return type.getShape().take_back(n + 1).front();
1834 };
1835 int64_t destinationRank =
1836 llvm::isa<VectorType>(extractOp.getType())
1837 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1838 : 0;
1839 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1840 return Value();
1841 if (destinationRank > 0) {
1842 auto destinationType =
1843 llvm::cast<VectorType>(extractOp.getResult().getType());
1844 for (int64_t i = 0; i < destinationRank; i++) {
1845 // The lowest dimension of the destination must match the lowest
1846 // dimension of the shapecast op source.
1847 // TODO: This case could be supported in a canonicalization pattern.
1848 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1849 getDimReverse(destinationType, i))
1850 return Value();
1851 }
1852 }
1853 // Extract the strides associated with the extract op vector source. Then use
1854 // this to calculate a linearized position for the extract.
1855 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1856 std::reverse(extractedPos.begin(), extractedPos.end());
1857 SmallVector<int64_t, 4> strides;
1858 int64_t stride = 1;
1859 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1860 strides.push_back(stride);
1861 stride *=
1862 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1863 }
1864
1865 int64_t position = linearize(extractedPos, strides);
1866 // Then extract the strides associated to the shapeCast op vector source and
1867 // delinearize the position using those strides.
1868 SmallVector<int64_t, 4> newStrides;
1869 int64_t numDimension =
1870 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1871 stride = 1;
1872 for (int64_t i = 0; i < numDimension; i++) {
1873 newStrides.push_back(stride);
1874 stride *=
1875 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1876 }
1877 std::reverse(newStrides.begin(), newStrides.end());
1878 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1879 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1880 OpBuilder b(extractOp.getContext());
1881 extractOp.setStaticPosition(newPosition);
1882 extractOp.setOperand(0, shapeCastOp.getSource());
1883 return extractOp.getResult();
1884}
1885
1886/// Fold an ExtractOp from ExtractStridedSliceOp.
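///
/// Example (illustrative):
///   %s = vector.extract_strided_slice %v
///        {offsets = [1, 2], sizes = [2, 2], strides = [1, 1]}
///        : vector<4x8xf32> to vector<2x2xf32>
///   %e = vector.extract %s[1, 0] : f32 from vector<2x2xf32>
/// folds in place, by adding the slice offsets to the extract position, to
///   %e = vector.extract %v[2, 2] : f32 from vector<4x8xf32>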
1887static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1888 // TODO: Canonicalization for dynamic position not implemented yet.
1889 if (extractOp.hasDynamicPosition())
1890 return Value();
1891
1892 auto extractStridedSliceOp =
1893 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1894 if (!extractStridedSliceOp)
1895 return Value();
1896
1897 // 0-D vectors not supported.
1898 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1899 if (hasZeroDimVectors(extractStridedSliceOp))
1900 return Value();
1901
1902 // Return if 'extractStridedSliceOp' has non-unit strides.
1903 if (extractStridedSliceOp.hasNonUnitStrides())
1904 return Value();
1905
1906 // Trim offsets for dimensions fully extracted.
1907 auto sliceOffsets =
1908 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1909 while (!sliceOffsets.empty()) {
1910 size_t lastOffset = sliceOffsets.size() - 1;
1911 if (sliceOffsets.back() != 0 ||
1912 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1913 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1914 break;
1915 sliceOffsets.pop_back();
1916 }
1917 unsigned destinationRank = 0;
1918 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1919 destinationRank = vecType.getRank();
1920 // The dimensions of the result need to be untouched by the
1921 // extractStridedSlice op.
1922 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1923 sliceOffsets.size())
1924 return Value();
1925
1926 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1927 assert(extractedPos.size() >= sliceOffsets.size());
1928 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1929 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1930 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1931
1932 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1933 OpBuilder b(extractOp.getContext());
1934 extractOp.setStaticPosition(extractedPos);
1935 return extractOp.getResult();
1936}
1937
1938/// Fold extract_op fed from a chain of insertStridedSlice ops.
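///
/// Example (illustrative):
///   %i = vector.insert_strided_slice %src, %dst
///        {offsets = [2, 0], strides = [1, 1]}
///        : vector<2x4xf32> into vector<4x4xf32>
///   %e = vector.extract %i[3] : vector<4xf32> from vector<4x4xf32>
/// The extracted row lies entirely inside the inserted chunk, so this folds to
///   %e = vector.extract %src[1] : vector<4xf32> from vector<2x4xf32>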
1939static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1940 // TODO: Canonicalization for dynamic position not implemented yet.
1941 if (extractOp.hasDynamicPosition())
1942 return Value();
1943
1944 int64_t destinationRank =
1945 llvm::isa<VectorType>(extractOp.getType())
1946 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1947 : 0;
1948 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1949 if (!insertOp)
1950 return Value();
1951
1952 // 0-D vectors not supported.
1953 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1954 if (hasZeroDimVectors(insertOp))
1955 return Value();
1956
1957 while (insertOp) {
1958 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1959 insertOp.getSourceVectorType().getRank();
1960 if (destinationRank > insertOp.getSourceVectorType().getRank())
1961 return Value();
1962 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1963 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1964
1965 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1966 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1967 }))
1968 return Value();
1969 bool disjoint = false;
1970 SmallVector<int64_t, 4> offsetDiffs;
1971 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1972 int64_t start = insertOffsets[dim];
1973 int64_t size =
1974 (dim < insertRankDiff)
1975 ? 1
1976 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1977 int64_t end = start + size;
1978 int64_t offset = extractOffsets[dim];
1979 // Check if the start of the extract offset is in the interval inserted.
1980 if (start <= offset && offset < end) {
1981 if (dim >= insertRankDiff)
1982 offsetDiffs.push_back(offset - start);
1983 continue;
1984 }
1985 disjoint = true;
1986 break;
1987 }
1988 // The extracted chunk overlaps with the inserted vector.
1989 if (!disjoint) {
1990 // If any of the inner dimensions are only partially inserted we have a
1991 // partial overlap.
1992 int64_t srcRankDiff =
1993 insertOp.getSourceVectorType().getRank() - destinationRank;
1994 for (int64_t i = 0; i < destinationRank; i++) {
1995 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1996 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1997 insertRankDiff))
1998 return Value();
1999 }
2000 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2001 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2002 OpBuilder b(extractOp.getContext());
2003 extractOp.setStaticPosition(offsetDiffs);
2004 return extractOp.getResult();
2005 }
2006 // If the chunk extracted is disjoint from the chunk inserted, keep
2007 // looking in the insert chain.
2008 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2009 }
2010 return Value();
2011}
2012
2013/// Try to fold the extraction of a scalar from a vector defined by
2014/// vector.from_elements. E.g.:
2015///
2016/// %0 = vector.from_elements %a, %b : vector<2xf32>
2017/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2018/// ==> fold to %a
2019static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2020 // Dynamic extractions cannot be folded.
2021 if (extractOp.hasDynamicPosition())
2022 return {};
2023
2024 // Look for extract(from_elements).
2025 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2026 if (!fromElementsOp)
2027 return {};
2028
2029 // Scalable vectors are not supported.
2030 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2031 if (vecType.isScalable())
2032 return {};
2033
2034 // Only extractions of scalars are supported.
2035 int64_t rank = vecType.getRank();
2036 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2037 if (extractOp.getType() != vecType.getElementType())
2038 return {};
2039 assert(static_cast<int64_t>(indices.size()) == rank &&
2040 "unexpected number of indices");
2041
2042 // Compute flattened/linearized index and fold to operand.
2043 int flatIndex = 0;
2044 int stride = 1;
2045 for (int i = rank - 1; i >= 0; --i) {
2046 flatIndex += indices[i] * stride;
2047 stride *= vecType.getDimSize(i);
2048 }
2049 return fromElementsOp.getElements()[flatIndex];
2050}
2051
2052/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2053/// then fold it.
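///
/// Example (illustrative), assuming %c1 is an arith.constant with value 1:
///   %e = vector.extract %v[%c1, %i] : f32 from vector<4x4xf32>
/// is rewritten in place to
///   %e = vector.extract %v[1, %i] : f32 from vector<4x4xf32>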
2054template <typename OpType, typename AdaptorType>
2055static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2056 SmallVectorImpl<Value> &operands) {
2057 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2058 OperandRange dynamicPosition = op.getDynamicPosition();
2059 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2060 ArrayRef<int64_t> vectorShape;
2061 if constexpr (std::is_same_v<OpType, ExtractOp>)
2062 vectorShape = op.getSourceVectorType().getShape();
2063 else
2064 vectorShape = op.getDestVectorType().getShape();
2065
2066 // If there are no dynamic position operands, there is nothing to fold here.
2067 if (!dynamicPosition.size())
2068 return {};
2069
2070 // `index` is used to iterate over the `dynamicPosition`.
2071 unsigned index = 0;
2072
2073 // `opChange` is a flag. If it is true, it means to update `op` in place.
2074 bool opChange = false;
2075 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2076 if (ShapedType::isStatic(staticPosition[i]))
2077 continue;
2078 Attribute positionAttr = dynamicPositionAttr[index];
2079 Value position = dynamicPosition[index++];
2080 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2081 int64_t value = attr.getInt();
2082 // Do not fold if the value is out of bounds (-1 signifies a poison
2083 // value rather than OOB index).
2084 if (value >= -1 && value < vectorShape[i]) {
2085 staticPosition[i] = attr.getInt();
2086 opChange = true;
2087 continue;
2088 }
2089 }
2090 operands.push_back(position);
2091 }
2092
2093 if (opChange) {
2094 op.setStaticPosition(staticPosition);
2095 op.getOperation()->setOperands(operands);
2096 // Return the original result to indicate an in-place folding happened.
2097 return op.getResult();
2098 }
2099 return {};
2100}
2101
2102/// Fold an insert or extract operation into a poison value when a poison index
2103/// is found at any dimension of the static position.
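///
/// Example (illustrative): with kPoisonIndex == -1,
///   %e = vector.extract %v[-1] : f32 from vector<4xf32>
/// folds to a ub.poison value.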
2104 static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2105 ArrayRef<int64_t> staticPos,
2106 int64_t poisonVal) {
2107 if (!is_contained(staticPos, poisonVal))
2108 return {};
2109
2110 return ub::PoisonAttr::get(context);
2111}
2112
2113/// Fold a vector extract whose source is a poison value.
2114 static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2115 if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2116 return srcAttr;
2117
2118 return {};
2119}
2120
2121/// Fold a vector extract extracting from a DenseElementsAttr.
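///
/// Example (illustrative):
///   %cst = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
///   %e   = vector.extract %cst[1] : vector<2xi32> from vector<2x2xi32>
/// folds to the constant dense<[2, 3]> : vector<2xi32>.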
2122 static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2123 Attribute srcAttr) {
2124 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2125 if (!denseAttr) {
2126 return {};
2127 }
2128
2129 if (denseAttr.isSplat()) {
2130 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2131 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2132 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2133 return newAttr;
2134 }
2135
2136 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2137 if (vecTy.isScalable())
2138 return {};
2139
2140 if (extractOp.hasDynamicPosition()) {
2141 return {};
2142 }
2143
2144 // Materializing subsets of a large constant array can generally lead to
2145 // explosion in IR size because of different combination of subsets that
2146 // can exist. However, vector.extract is a restricted form of subset
2147 // extract where you can only extract non-overlapping (or the same) subset for
2148 // a given rank of the subset. Because of this property, the IR size can only
2149 // increase at most by `rank * size(array)` from a single constant array being
2150 // extracted by multiple extracts.
2151
2152 // Calculate the linearized position of the continuous chunk of elements to
2153 // extract.
2154 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2155 copy(extractOp.getStaticPosition(), completePositions.begin());
2156 int64_t startPos =
2157 linearize(completePositions, computeStrides(vecTy.getShape()));
2158 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2159
2160 TypedAttr newAttr;
2161 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2162 SmallVector<Attribute> elementValues(
2163 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2164 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2165 } else {
2166 newAttr = *denseValuesBegin;
2167 }
2168
2169 return newAttr;
2170}
2171
2172OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2173 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2174 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2175 // mismatch).
2176 if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2177 return getSource();
2178 if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2179 return res;
2180 // Fold `arith.constant` indices into the `vector.extract` operation.
2181 // Do not stop here as this fold may enable subsequent folds that require
2182 // constant indices.
2183 SmallVector<Value> operands = {getSource()};
2184 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2185
2186 if (auto res = foldPoisonIndexInsertExtractOp(
2187 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2188 return res;
2189 if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2190 return res;
2191 if (succeeded(foldExtractOpFromExtractChain(*this)))
2192 return getResult();
2193 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2194 return res;
2195 if (auto res = foldExtractFromBroadcast(*this))
2196 return res;
2197 if (auto res = foldExtractFromShuffle(*this))
2198 return res;
2199 if (auto res = foldExtractFromShapeCast(*this))
2200 return res;
2201 if (auto val = foldExtractFromExtractStrided(*this))
2202 return val;
2203 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2204 return val;
2205 if (auto val = foldScalarExtractFromFromElements(*this))
2206 return val;
2207
2208 return inplaceFolded;
2209}
2210
2211namespace {
2212
2213// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
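// Example (illustrative):
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
// becomes
//   %e = vector.broadcast %src : vector<4xf32> to vector<3x4xf32>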
2214class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2215public:
2216 using Base::Base;
2217
2218 LogicalResult matchAndRewrite(ExtractOp extractOp,
2219 PatternRewriter &rewriter) const override {
2220
2221 Operation *defOp = extractOp.getSource().getDefiningOp();
2222 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2223 if (!defOp || !isBroadcastLike(defOp) || !outType)
2224 return failure();
2225
2226 Value source = defOp->getOperand(0);
2227 if (isBroadcastableTo(source.getType(), outType) !=
2228 BroadcastableToResult::Success)
2229 return failure();
2230
2231 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2232 return success();
2233 }
2234};
2235
2236// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
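// Example (illustrative), assuming %c2 and %c3 are arith.constant values 2 and
// 3 and %d is a dynamic index:
//   %m = vector.create_mask %c2, %c3, %d : vector<4x4x4xi1>
//   %e = vector.extract %m[1] : vector<4x4xi1> from vector<4x4x4xi1>
// becomes (1 < 2, so the extracted slice keeps the remaining mask bounds)
//   %e = vector.create_mask %c3, %d : vector<4x4xi1>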
2237class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2238public:
2239 using Base::Base;
2240
2241 LogicalResult matchAndRewrite(ExtractOp extractOp,
2242 PatternRewriter &rewriter) const override {
2243 auto createMaskOp =
2244 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2245 if (!createMaskOp)
2246 return failure();
2247
2248 VectorType extractedMaskType =
2249 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2250
2251 if (!extractedMaskType)
2252 return failure();
2253
2254 auto maskOperands = createMaskOp.getOperands();
2255 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2256 VectorType maskType = createMaskOp.getVectorType();
2257
2258 bool containsUnknownDims = false;
2259 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2260
2261 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2262 dimIdx++) {
2263 int64_t pos = extractOpPos[dimIdx];
2264 Value operand = maskOperands[dimIdx];
2265 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2266 if (!constantOp) {
2267 // Bounds of this dim unknown.
2268 containsUnknownDims = true;
2269 continue;
2270 }
2271
2272 int64_t createMaskBound =
2273 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2274
2275 if (pos != ShapedType::kDynamic) {
2276 // If any position is outside the range from the `create_mask`, then the
2277 // extracted mask will be all-false.
2278 allFalse |= pos >= createMaskBound;
2279 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2280 // This dim is not all-true and since this is a dynamic index we don't
2281 // know if the extraction is within the true or false region.
2282 // Note: Zero dims have already been handled via getMaskFormat().
2283 containsUnknownDims = true;
2284 }
2285 }
2286
2287 if (allFalse) {
2288 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2289 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2290 } else if (!containsUnknownDims) {
2291 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2292 extractOp, extractedMaskType,
2293 maskOperands.drop_front(extractOpPos.size()));
2294 } else {
2295 return failure();
2296 }
2297 return success();
2298 }
2299};
2300
2301// Folds extract(shape_cast(..)) into shape_cast when the total element count
2302// does not change.
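//
// Example (illustrative):
//   %sc = vector.shape_cast %v : vector<2x2xf32> to vector<1x4xf32>
//   %e  = vector.extract %sc[0] : vector<4xf32> from vector<1x4xf32>
// becomes
//   %e  = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>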
2303LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2304 PatternRewriter &rewriter) {
2305 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2306 if (!castOp)
2307 return failure();
2308
2309 VectorType sourceType = castOp.getSourceVectorType();
2310 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2311 if (!targetType)
2312 return failure();
2313
2314 if (sourceType.getNumElements() != targetType.getNumElements())
2315 return failure();
2316
2317 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2318 castOp.getSource());
2319 return success();
2320}
2321
2322/// Try to canonicalize the extraction of a subvector from a vector defined by
2323/// vector.from_elements. E.g.:
2324///
2325/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2326/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2327/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2328LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2329 PatternRewriter &rewriter) {
2330 // Dynamic positions are not supported.
2331 if (extractOp.hasDynamicPosition())
2332 return failure();
2333
2334 // Scalar extracts are handled by the folder.
2335 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2336 if (!resultType)
2337 return failure();
2338
2339 // Look for extracts from a from_elements op.
2340 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2341 if (!fromElementsOp)
2342 return failure();
2343 VectorType inputType = fromElementsOp.getType();
2344
2345 // Scalable vectors are not supported.
2346 if (resultType.isScalable() || inputType.isScalable())
2347 return failure();
2348
2349 // Compute the position of first extracted element and flatten/linearize the
2350 // position.
2351 SmallVector<int64_t> firstElementPos =
2352 llvm::to_vector(extractOp.getStaticPosition());
2353 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2354 int flatIndex = 0;
2355 int stride = 1;
2356 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2357 flatIndex += firstElementPos[i] * stride;
2358 stride *= inputType.getDimSize(i);
2359 }
2360
2361 // Replace the op with a smaller from_elements op.
2362 rewriter.replaceOpWithNewOp<FromElementsOp>(
2363 extractOp, resultType,
2364 fromElementsOp.getElements().slice(flatIndex,
2365 resultType.getNumElements()));
2366 return success();
2367}
2368
2369/// Replace `vector.extract` with `vector.shape_cast`.
2370///
2371/// BEFORE:
2372/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2373/// AFTER:
2374/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2375///
2376/// The canonical form of vector operations that reshape vectors is shape_cast.
2377struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2378 using Base::Base;
2379 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2380 PatternRewriter &rewriter) const override {
2381 VectorType sourceType = extractOp.getSourceVectorType();
2382 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2383 if (!outType)
2384 return failure();
2385
2386 if (sourceType.getNumElements() != outType.getNumElements())
2387 return rewriter.notifyMatchFailure(
2388 extractOp, "extract to vector with fewer elements");
2389
2390 // Negative values in `position` mean that the extracted value is poison.
2391 // There is a vector.extract folder for this.
2392 if (llvm::any_of(extractOp.getMixedPosition(),
2393 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2394 return rewriter.notifyMatchFailure(extractOp,
2395 "leaving for extract poison folder");
2396
2397 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2398 extractOp.getSource());
2399
2400 return success();
2401 }
2402};
2403
2404} // namespace
2405
2406void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2407 MLIRContext *context) {
2408 results
2409 .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
2410 context);
2411 results.add(foldExtractFromShapeCastToShapeCast);
2412 results.add(foldExtractFromFromElements);
2413}
2414
2415 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2416 SmallVectorImpl<int64_t> &results) {
2417 for (auto attr : arrayAttr)
2418 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2419}
2420
2421//===----------------------------------------------------------------------===//
2422// FMAOp
2423//===----------------------------------------------------------------------===//
2424
2425std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2426 return llvm::to_vector<4>(getVectorType().getShape());
2427}
2428
2429//===----------------------------------------------------------------------===//
2430// ToElementsOp
2431//===----------------------------------------------------------------------===//
2432
2433/// Returns true if all the `operands` are defined by `defOp`.
2434/// Otherwise, returns false.
2435static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2436 if (operands.empty())
2437 return false;
2438
2439 return llvm::all_of(operands, [&](Value operand) {
2440 Operation *currentDef = operand.getDefiningOp();
2441 return currentDef == defOp;
2442 });
2443}
2444
2445/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2446/// (%e0, %e1, ...). For example:
2447///
2448/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2449/// %1:3 = vector.to_elements %0 : vector<3xf32>
2450/// user_op %1#0, %1#1, %1#2
2451///
2452/// becomes:
2453///
2454/// user_op %a, %b, %c
2455///
2456static LogicalResult
2457foldToElementsFromElements(ToElementsOp toElementsOp,
2458 SmallVectorImpl<OpFoldResult> &results) {
2459 auto fromElementsOp =
2460 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2461 if (!fromElementsOp)
2462 return failure();
2463
2464 llvm::append_range(results, fromElementsOp.getElements());
2465 return success();
2466}
2467
2468/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2469///
2470/// Example:
2471/// %b = vector.broadcast %x : f32 to vector<3xf32>
2472/// %e:3 = vector.to_elements %b : vector<3xf32>
2473/// user_op %e#0, %e#1, %e#2
2474/// becomes:
2475/// user_op %x, %x, %x
2476///
2477/// The vector source case is handled by a canonicalization pattern.
2478static LogicalResult
2479foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2480 SmallVectorImpl<OpFoldResult> &results) {
2481 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2482 if (!bcastOp)
2483 return failure();
2484 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2485 if (isa<VectorType>(bcastOp.getSource().getType()))
2486 return failure();
2487
2488 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2489
2490 Value scalar = bcastOp.getSource();
2491 results.assign(resultVecType.getNumElements(), scalar);
2492 return success();
2493}
2494
2495LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2496 SmallVectorImpl<OpFoldResult> &results) {
2497 if (succeeded(foldToElementsFromElements(*this, results)))
2498 return success();
2499 return foldToElementsOfBroadcast(*this, results);
2500}
2501
2502LogicalResult
2503ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2504 ToElementsOp::Adaptor adaptor,
2505 SmallVectorImpl<Type> &inferredReturnTypes) {
2506 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2507 Type elType = vecType.getElementType();
2508 inferredReturnTypes.append(vecType.getNumElements(), elType);
2509 return success();
2510}
2511
2512/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2513/// vector.
2514/// - Build `vector.to_elements %v` and remap each destination element to the
2515/// corresponding source element using broadcast rules (match or 1 →
2516/// replicate).
2517///
2518/// Example:
2519/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2520/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2521/// becomes:
2522/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2523/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2524/// // %src_elems#1, %src_elems#0, %src_elems#1
2525struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2526 using Base::Base;
2527
2528 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2529 PatternRewriter &rewriter) const override {
2530 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2531 if (!bcastOp)
2532 return failure();
2533
2534 // Only handle broadcasts from a vector source here.
2535 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2536 if (!srcType)
2537 return failure();
2538
2539 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2540
2541 ArrayRef<int64_t> dstShape = dstType.getShape();
2542 ArrayRef<int64_t> srcShape = srcType.getShape();
2543
2544 int64_t dstRank = dstShape.size();
2545 int64_t srcRank = srcShape.size();
2546
2547 // Create elements for the broadcast source vector.
2548 auto srcElems = vector::ToElementsOp::create(
2549 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2550
2551 int64_t dstCount = llvm::product_of(dstShape);
2552
2553 SmallVector<Value> replacements;
2554 replacements.reserve(dstCount);
2555
2556 // For each element of the destination, determine which element of the
2557 // source should be used. We walk all destination positions using a single
2558 // counter, decode it into per-dimension indices, then build the matching
2559 // source position: use the same index where sizes match, and use 0 where
2560 // the source size is 1 (replication). This mapping is needed so we can
2561 // replace each result of to_elements with the corresponding element from
2562 // the broadcast source.
2563 // Inner-dimension stretch example:
2564 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2565 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2566 // becomes:
2567 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2568 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2569 // // %src_elems#1, %src_elems#0, %src_elems#1,
2570 // // %src_elems#2, %src_elems#3, %src_elems#2,
2571 // // %src_elems#3, %src_elems#2, %src_elems#3
2572
2573 // Row-major strides for the destination shape.
2574 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2575 // Row-major strides for the source shape.
2576 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2577 SmallVector<int64_t> dstIdx(dstRank);
2578 SmallVector<int64_t> srcIdx(srcRank);
2579 for (int64_t lin = 0; lin < dstCount; ++lin) {
2580 // Convert linear destination index to per-dimension indices.
2581 dstIdx = delinearize(lin, dstStrides);
2582 for (int64_t k = 0; k < srcRank; ++k)
2583 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2584 // Convert per-dimension source indices back to a linear index.
2585 int64_t srcLin = linearize(srcIdx, srcStrides);
2586 replacements.push_back(srcElems.getResult(srcLin));
2587 }
2588
2589 rewriter.replaceOp(toElementsOp, replacements);
2590 return success();
2591 }
2592};
2593
2594void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2595 MLIRContext *context) {
2596 results.add<ToElementsOfBroadcast>(context);
2597}
2598
2599//===----------------------------------------------------------------------===//
2600// FromElementsOp
2601//===----------------------------------------------------------------------===//
2602
2603/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2604///
2605/// Case #1: Input and output vectors are the same.
2606///
2607/// %0:3 = vector.to_elements %a : vector<3xf32>
2608/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2609/// user_op %1
2610///
2611/// becomes:
2612///
2613/// user_op %a
2614///
2615static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2616 OperandRange fromElemsOperands = fromElementsOp.getElements();
2617 if (fromElemsOperands.empty())
2618 return {};
2619
2620 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2621 if (!toElementsOp)
2622 return {};
2623
2624 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2625 return {};
2626
2627 // Case #1: Input and output vectors are the same. Forward the input vector.
2628 Value toElementsInput = toElementsOp.getSource();
2629 if (fromElementsOp.getType() == toElementsInput.getType() &&
2630 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2631 return toElementsInput;
2632 }
2633
2634 // TODO: Support cases with different input and output shapes and different
2635 // number of elements.
2636
2637 return {};
2638}
2639
2640/// Fold vector.from_elements to a constant when all operands are constants.
2641/// Example:
2642/// %c1 = arith.constant 1 : i32
2643/// %c2 = arith.constant 2 : i32
2644/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2645/// =>
2646/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2647///
2648static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2649 ArrayRef<Attribute> elements) {
2650 // Check for null or poison attributes before any processing.
2651 if (llvm::any_of(elements, [](Attribute attr) {
2652 return !attr || isa<ub::PoisonAttrInterface>(attr);
2653 }))
2654 return {};
2655
2656 // DenseElementsAttr only supports int/index/float/complex types.
2657 auto destVecType = fromElementsOp.getDest().getType();
2658 auto destEltType = destVecType.getElementType();
2659 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2660 return {};
2661
2662 // Constant attributes might have a different type than the return type.
2663 // Convert them before creating the dense elements attribute.
2664 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2665 return convertNumericAttr(attr, destEltType);
2666 });
2667
2668 return DenseElementsAttr::get(destVecType, convertedElements);
2669}
2670
2671OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2672 if (auto res = foldFromElementsToElements(*this))
2673 return res;
2674 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2675 return res;
2676
2677 return {};
2678}
2679
2680/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2681/// same. Example:
2682/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2683/// =>
2684/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2685static LogicalResult
2686rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2687 PatternRewriter &rewriter) {
2688 if (!llvm::all_equal(fromElementsOp.getElements()))
2689 return failure();
2690 rewriter.replaceOpWithNewOp<BroadcastOp>(
2691 fromElementsOp, fromElementsOp.getType(),
2692 fromElementsOp.getElements().front());
2693 return success();
2694}
2695
2696/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2697/// on a single extract. Example:
2698/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2699/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2700/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2701///
2702/// becomes
2703/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2704/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2705///
2706/// The requirements for this to be valid are
2707///
2708/// i) The elements are extracted from the same vector (%source).
2709///
2710/// ii) The elements form a suffix of %source. Specifically, the number
2711/// of elements is the same as the product of the last N dimension sizes
2712/// of %source, for some N.
2713///
2714/// iii) The elements are extracted contiguously in ascending order.
2715
2716class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2717
2718 using Base::Base;
2719
2720 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2721 PatternRewriter &rewriter) const override {
2722
2723 // Handled by `rewriteFromElementsAsBroadcast`.
2724 if (fromElements.getType().getNumElements() == 1)
2725 return failure();
2726
2727 // The common source that all elements are extracted from, if one exists.
2728 Value source;
2729 // The position of the combined extract operation, if one is created.
2730 ArrayRef<int64_t> combinedPosition;
2731 // The expected index of extraction of the current element in the loop, if
2732 // elements are extracted contiguously in ascending order.
2733 SmallVector<int64_t> expectedPosition;
2734
2735 for (auto [insertIndex, element] :
2736 llvm::enumerate(fromElements.getElements())) {
2737
2738 // Check that the element is from a vector.extract operation.
2739 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2740 if (!extractOp) {
2741 return rewriter.notifyMatchFailure(fromElements,
2742 "element not from vector.extract");
2743 }
2744
2745 // Check condition (i) by checking that all elements have the same source
2746 // as the first element.
2747 if (insertIndex == 0) {
2748 source = extractOp.getSource();
2749 } else if (extractOp.getSource() != source) {
2750 return rewriter.notifyMatchFailure(fromElements,
2751 "element from different vector");
2752 }
2753
2754 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2755 int64_t rank = position.size();
2756 assert(rank == source.getType().getRank() &&
2757 "scalar extract must have full rank position");
2758
2759 // Check condition (ii) by checking that the position that the first
2760 // element is extracted from has sufficient trailing 0s. For example, in
2761 //
2762 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2763 // [...]
2764 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2765 //
2766 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2767 // elements, which is the number of elements of %elms, so this is valid.
2768 if (insertIndex == 0) {
2769 const int64_t numElms = fromElements.getType().getNumElements();
2770 int64_t numSuffixElms = 1;
2771 int64_t index = rank;
2772 while (index > 0 && position[index - 1] == 0 &&
2773 numSuffixElms < numElms) {
2774 numSuffixElms *= source.getType().getDimSize(index - 1);
2775 --index;
2776 }
2777 if (numSuffixElms != numElms) {
2778 return rewriter.notifyMatchFailure(
2779 fromElements, "elements do not form a suffix of source");
2780 }
2781 expectedPosition = llvm::to_vector(position);
2782 combinedPosition = position.drop_back(rank - index);
2783 }
2784
2785 // Check condition (iii).
2786 else if (expectedPosition != position) {
2787 return rewriter.notifyMatchFailure(
2788 fromElements, "elements not in ascending order (static order)");
2789 }
2790 increment(expectedPosition, source.getType().getShape());
2791 }
2792
2793 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2794 fromElements.getLoc(), source, combinedPosition);
2795
2796 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2797 fromElements, fromElements.getType(), extracted);
2798
2799 return success();
2800 }
2801
2802 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2803 static void increment(MutableArrayRef<int64_t> indices,
2804 ArrayRef<int64_t> shape) {
2805 for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2806 indices[dim] += 1;
2807 if (indices[dim] < shape[dim])
2808 break;
2809 indices[dim] = 0;
2810 }
2811 }
2812};
2813
2814void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2815 MLIRContext *context) {
2816 results.add(rewriteFromElementsAsBroadcast);
2817 results.add<FromElementsToShapeCast>(context);
2818}
2819
2820//===----------------------------------------------------------------------===//
2821// BroadcastOp
2822//===----------------------------------------------------------------------===//
2823
2824void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2825 SetIntRangeFn setResultRanges) {
2826 setResultRanges(getResult(), argRanges.front());
2827}
2828
2829std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2830 return llvm::to_vector<4>(getResultVectorType().getShape());
2831}
2832
2833/// Return the dimensions of the result vector that were formerly ones in the
2834/// source vector and thus correspond to "dim-1" broadcasting.
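///
/// For example (illustrative), broadcasting vector<1x4xf32> to
/// vector<2x3x5x4xf32> returns {2}: dims 0 and 1 are plain rank-expanding
/// broadcasts, while dim 2 stretches a former unit dim (1 -> 5).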
2835static llvm::SetVector<int64_t>
2836 computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2837 ArrayRef<int64_t> dstShape) {
2838 int64_t rankDiff = dstShape.size() - srcShape.size();
2839 int64_t dstDim = rankDiff;
2840 llvm::SetVector<int64_t> res;
2841 for (auto [s1, s2] :
2842 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2843 if (s1 != s2) {
2844 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2845 res.insert(dstDim);
2846 }
2847 ++dstDim;
2848 }
2849 return res;
2850}
2851
2852llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2853 // A broadcast from a scalar does not involve any unit-dim broadcast.
2854 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2855 if (!srcVectorType)
2856 return {};
2857 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2858 getResultVectorType().getShape());
2859}
2860
2861/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2862/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2863/// This requires (and asserts) that the broadcast is free of "dim-1"
2864/// broadcasting.
2865/// Since vector.broadcast only allows expanding leading dimensions, an extra
2866/// vector.transpose may be inserted to make the broadcast possible.
2867/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2868/// the helper will assert. This means:
2869/// 1. `dstShape` must not be empty.
2870/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2871/// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2872///    must match the `value` shape.
2873Value BroadcastOp::createOrFoldBroadcastOp(
2874 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2875 const llvm::SetVector<int64_t> &broadcastedDims) {
2876 assert(!dstShape.empty() && "unexpected empty dst shape");
2877
2878 // Well-formedness check.
2879 SmallVector<int64_t> checkShape;
2880 for (int i = 0, e = dstShape.size(); i < e; ++i) {
2881 if (broadcastedDims.contains(i))
2882 continue;
2883 checkShape.push_back(dstShape[i]);
2884 }
2885 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2886 "ill-formed broadcastedDims contains values not confined to "
2887 "destVectorShape");
2888
2889 Location loc = value.getLoc();
2890 Type elementType = getElementTypeOrSelf(value.getType());
2891 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2892 VectorType dstVectorType = VectorType::get(dstShape, elementType);
2893
2894 // Step 2. If scalar -> dstShape broadcast, just do it.
2895 if (!srcVectorType) {
2896 assert(checkShape.empty() &&
2897 "ill-formed createOrFoldBroadcastOp arguments");
2898 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2899 }
2900
2901 assert(srcVectorType.getShape().equals(checkShape) &&
2902 "ill-formed createOrFoldBroadcastOp arguments");
2903
2904 // Step 3. Since vector.broadcast only allows creating leading dims,
2905 // vector -> dstShape broadcast may require a transpose.
2906 // Traverse the dims in order and construct:
2907 // 1. The leading entries of the broadcastShape that is guaranteed to be
2908 // achievable by a simple broadcast.
2909 // 2. The induced permutation for the subsequent vector.transpose that will
2910 // bring us from `broadcastShape` back to the desired `dstShape`.
2911 // If the induced permutation is not the identity, create a vector.transpose.
2912 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2913 broadcastShape.reserve(dstShape.size());
2914 // Consider the example:
2915 // srcShape = 2x4
2916 // dstShape = 1x2x3x4x5
2917 // broadcastedDims = [0, 2, 4]
2918 //
2919 // We want to build:
2920 // broadcastShape = 1x3x5x2x4
2921 // permutation = [0, 2, 4, 1, 3]
2922 // ---V--- -----V-----
2923 // leading broadcast part src shape part
2924 //
2925 // Note that the trailing dims of broadcastShape are exactly the srcShape
2926 // by construction.
2927 // nextSrcShapeDim is used to keep track of where in the permutation the
2928 // "src shape part" occurs.
2929 int64_t nextSrcShapeDim = broadcastedDims.size();
2930 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2931 if (broadcastedDims.contains(i)) {
2932 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2933 // bring it to the head of the broadcastShape.
2934 // It will need to be permuted back from `broadcastShape.size() - 1` into
2935 // position `i`.
2936 broadcastShape.push_back(dstShape[i]);
2937 permutation[i] = broadcastShape.size() - 1;
2938 } else {
2939 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2940 // shape and needs to be permuted into position `i`.
2941 // Don't touch `broadcastShape` here, the whole srcShape will be
2942 // appended after.
2943 permutation[i] = nextSrcShapeDim++;
2944 }
2945 }
2946 // 3.c. Append the srcShape.
2947 llvm::append_range(broadcastShape, srcVectorType.getShape());
2948
2949 // Ensure there are no "dim-1" broadcasts.
2950 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2951 .empty() &&
2952 "unexpected \"dim-1\" broadcast");
2953
2954 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2955 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2956 vector::BroadcastableToResult::Success &&
2957 "must be broadcastable");
2958 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2959 // Step 4. If we find any dimension that indeed needs to be permuted,
2960 // immediately return a new vector.transpose.
2961 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2962 if (permutation[i] != i)
2963 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2964 // Otherwise return res.
2965 return res;
2966}
2967
2968 BroadcastableToResult mlir::vector::isBroadcastableTo(
2969 Type srcType, VectorType dstVectorType,
2970 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2971 // Broadcast scalar to vector of the same element type.
2972 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2973 srcType == getElementTypeOrSelf(dstVectorType))
2974 return BroadcastableToResult::Success;
2975 // From now on, only vectors broadcast.
2976 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2977 if (!srcVectorType)
2978 return BroadcastableToResult::SourceTypeNotAVector;
2979
2980 int64_t srcRank = srcVectorType.getRank();
2981 int64_t dstRank = dstVectorType.getRank();
2982 if (srcRank > dstRank)
2983 return BroadcastableToResult::SourceRankHigher;
2984 // Source has an exact match or singleton value for all trailing dimensions
2985 // (all leading dimensions are simply duplicated).
2986 int64_t lead = dstRank - srcRank;
2987 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2988 // Have mismatching dims (in the sense of vector.broadcast semantics) been
2989 // encountered?
2990 bool foundMismatchingDims = false;
2991
2992 // Check fixed-width dims.
2993 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2994 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2995 if (srcDim != 1 && srcDim != dstDim)
2996 foundMismatchingDims = true;
2997
2998 // Check scalable flags.
2999 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3000 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3001 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3002 // 1 -> [N] is fine, everything else should be rejected when mixing
3003 // fixed-width and scalable dims
3004 (srcDimScalableFlag != dstDimScalableFlag &&
3005 (srcDim != 1 || srcDimScalableFlag)))
3006 foundMismatchingDims = true;
3007
3008 if (foundMismatchingDims) {
3009 if (mismatchingDims != nullptr) {
3010 mismatchingDims->first.dim = srcDim;
3011 mismatchingDims->first.isScalable = srcDimScalableFlag;
3012
3013 mismatchingDims->second.dim = dstDim;
3014 mismatchingDims->second.isScalable = dstDimScalableFlag;
3015 }
3016 return BroadcastableToResult::DimensionMismatch;
3017 }
3018 }
3019
3020 return BroadcastableToResult::Success;
3021}
3022
3023LogicalResult BroadcastOp::verify() {
3024 std::pair<VectorDim, VectorDim> mismatchingDims;
3025 BroadcastableToResult res = isBroadcastableTo(
3026 getSourceType(), getResultVectorType(), &mismatchingDims);
3027 if (res == BroadcastableToResult::Success)
3028 return success();
3029 if (res == BroadcastableToResult::SourceRankHigher)
3030 return emitOpError("source rank higher than destination rank");
3031 if (res == BroadcastableToResult::DimensionMismatch) {
3032 return emitOpError("dimension mismatch (")
3033 << (mismatchingDims.first.isScalable ? "[" : "")
3034 << mismatchingDims.first.dim
3035 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
3036 << (mismatchingDims.second.isScalable ? "[" : "")
3037 << mismatchingDims.second.dim
3038 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
3039 }
3040 if (res == BroadcastableToResult::SourceTypeNotAVector)
3041 return emitOpError("source type is not a vector");
3042 llvm_unreachable("unexpected vector.broadcast op error");
3043}
3044
3045// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3046// with broadcast's result type and shape_cast only adds or removes ones in the
3047// leading dimensions.
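// Illustrative example (a sketch; the concrete types are hypothetical):
//   %0 = vector.shape_cast %x : vector<4xi8> to vector<1x4xi8>
//   %1 = vector.broadcast %0 : vector<1x4xi8> to vector<2x4xi8>
// folds in place to:
//   %1 = vector.broadcast %x : vector<4xi8> to vector<2x4xi8>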
3048static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3049 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3050 if (!srcShapeCast)
3051 return failure();
3052
3053 VectorType srcType = srcShapeCast.getSourceVectorType();
3054 VectorType destType = broadcastOp.getResultVectorType();
3055 // Check type compatibility.
3056 if (vector::isBroadcastableTo(srcType, destType) !=
3057 vector::BroadcastableToResult::Success)
3058 return failure();
3059
3060 ArrayRef<int64_t> srcShape = srcType.getShape();
3061 ArrayRef<int64_t> shapecastShape =
3062 srcShapeCast.getResultVectorType().getShape();
3063 // Trailing dimensions should be the same if shape_cast only alters the
3064 // leading dimensions.
3065 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3066 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3067 shapecastShape.take_back(numTrailingDims)))
3068 return failure();
3069
3070 assert(all_of(srcShape.drop_back(numTrailingDims),
3071 [](int64_t E) { return E == 1; }) &&
3072 all_of(shapecastShape.drop_back(numTrailingDims),
3073 [](int64_t E) { return E == 1; }) &&
3074 "ill-formed shape_cast");
3075
3076 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3077 return success();
3078}
3079
3080OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3081 if (getSourceType() == getResultVectorType())
3082 return getSource();
3083 if (succeeded(foldBroadcastOfShapeCast(*this)))
3084 return getResult();
3085
3086 if (!adaptor.getSource())
3087 return {};
3088 auto vectorType = getResultVectorType();
3089 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3090 if (vectorType.getElementType() != attr.getType())
3091 return {};
3092 return DenseElementsAttr::get(vectorType, attr);
3093 }
3094 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3095 if (vectorType.getElementType() != attr.getType())
3096 return {};
3097 return DenseElementsAttr::get(vectorType, attr);
3098 }
3099 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3100 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3101 if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
3102 return ub::PoisonAttr::get(getContext());
3103 return {};
3104}
3105
3106namespace {
3107
3108// Fold broadcast1(broadcast2(x)) into broadcast1(x).
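// Illustrative example (types are hypothetical):
//   %0 = vector.broadcast %x : f32 to vector<4xf32>
//   %1 = vector.broadcast %0 : vector<4xf32> to vector<2x4xf32>
// is rewritten to:
//   %1 = vector.broadcast %x : f32 to vector<2x4xf32>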
3109struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3110 using Base::Base;
3111
3112 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3113 PatternRewriter &rewriter) const override {
3114 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3115 if (!srcBroadcast)
3116 return failure();
3117 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3118 broadcastOp.getResultVectorType(),
3119 srcBroadcast.getSource());
3120 return success();
3121 }
3122};
3123
3124/// Replace `vector.broadcast` with `vector.shape_cast`.
3125///
3126/// BEFORE:
3127/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3128/// AFTER:
3129/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3130///
3131/// The canonical form of vector operations that reshape vectors is shape_cast.
3132struct BroadcastToShapeCast final
3133 : public OpRewritePattern<vector::BroadcastOp> {
3134 using Base::Base;
3135 LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
3136 PatternRewriter &rewriter) const override {
3137
3138 auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
3139 if (!sourceType) {
3140 return rewriter.notifyMatchFailure(
3141 broadcast, "source is a scalar, shape_cast doesn't support scalar");
3142 }
3143
3144 VectorType outType = broadcast.getType();
3145 if (sourceType.getNumElements() != outType.getNumElements()) {
3146 return rewriter.notifyMatchFailure(
3147 broadcast, "broadcast to a greater number of elements");
3148 }
3149
3150 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
3151 broadcast.getSource());
3152 return success();
3153 }
3154};
3155} // namespace
3156
3157void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3158 MLIRContext *context) {
3159 results.add<BroadcastFolder, BroadcastToShapeCast>(context);
3160}
3161
3162//===----------------------------------------------------------------------===//
3163// ShuffleOp
3164//===----------------------------------------------------------------------===//
3165
3166LogicalResult ShuffleOp::verify() {
3167 VectorType resultType = getResultVectorType();
3168 VectorType v1Type = getV1VectorType();
3169 VectorType v2Type = getV2VectorType();
3170 // Verify ranks.
3171 int64_t resRank = resultType.getRank();
3172 int64_t v1Rank = v1Type.getRank();
3173 int64_t v2Rank = v2Type.getRank();
3174 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3175 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3176 if (!wellFormed0DCase && !wellFormedNDCase)
3177 return emitOpError("rank mismatch");
3178
3179 // Verify all but leading dimension sizes.
3180 for (int64_t r = 1; r < v1Rank; ++r) {
3181 int64_t resDim = resultType.getDimSize(r);
3182 int64_t v1Dim = v1Type.getDimSize(r);
3183 int64_t v2Dim = v2Type.getDimSize(r);
3184 if (resDim != v1Dim || v1Dim != v2Dim)
3185 return emitOpError("dimension mismatch");
3186 }
3187 // Verify mask length.
3188 ArrayRef<int64_t> mask = getMask();
3189 int64_t maskLength = mask.size();
3190 if (maskLength <= 0)
3191 return emitOpError("invalid mask length");
3192 if (maskLength != resultType.getDimSize(0))
3193 return emitOpError("mask length mismatch");
3194 // Verify all indices.
3195 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3196 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3197 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3198 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3199 return emitOpError("mask index #") << (idx + 1) << " out of range";
3200 }
3201 return success();
3202}
3203
3204LogicalResult
3205ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
3206 ShuffleOp::Adaptor adaptor,
3207 SmallVectorImpl<Type> &inferredReturnTypes) {
3208 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
3209 auto v1Rank = v1Type.getRank();
3210 // Construct resulting type: leading dimension matches mask
3211 // length, all trailing dimensions match the operands.
3212 SmallVector<int64_t, 4> shape;
3213 shape.reserve(v1Rank);
3214 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3215 // In the 0-D case there is no trailing shape to append.
3216 if (v1Rank > 0)
3217 llvm::append_range(shape, v1Type.getShape().drop_front());
3218 inferredReturnTypes.push_back(
3219 VectorType::get(shape, v1Type.getElementType()));
3220 return success();
3221}
3222
3223template <typename T>
3224static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3225 T expected = begin;
3226 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3227 return value == expected++;
3228 });
3229}
3230
3231OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3232 auto v1Type = getV1VectorType();
3233 auto v2Type = getV2VectorType();
3234
3235 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3236 "Vector shuffle does not support scalable vectors");
3237
3238 // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
3239 // but must be a canonicalization into a vector.broadcast.
3240 if (v1Type.getRank() == 0)
3241 return {};
3242
3243 // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3244 auto mask = getMask();
3245 if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3246 return getV1();
3247 // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3248 if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3249 return getV2();
3250
3251 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3252 if (!v1Attr || !v2Attr)
3253 return {};
3254
3255 // Fold shuffle poison, poison -> poison.
3256 bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
3257 bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
3258 if (isV1Poison && isV2Poison)
3259 return ub::PoisonAttr::get(getContext());
3260
3261 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
3262 // manipulation.
3263 if (v1Type.getRank() != 1)
3264 return {};
3265
3266 // Poison input attributes need special handling as they are not
3267 // DenseElementsAttr. If an index is poison, we select the first element of
3268 // the first non-poison input.
3269 SmallVector<Attribute> v1Elements, v2Elements;
3270 Attribute poisonElement;
3271 if (!isV2Poison) {
3272 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3273 if (!v2DenseAttr)
3274 return {};
3275 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3276 poisonElement = v2Elements[0];
3277 }
3278 if (!isV1Poison) {
3279 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3280 if (!v1DenseAttr)
3281 return {};
3282 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3283 poisonElement = v1Elements[0];
3284 }
3285
3286 SmallVector<Attribute> results;
3287 int64_t v1Size = v1Type.getDimSize(0);
3288 for (int64_t maskIdx : mask) {
3289 Attribute indexedElm;
3290 // TODO: Return a partial poison vector when supported by the UB dialect.
3291 if (maskIdx == ShuffleOp::kPoisonIndex) {
3292 indexedElm = poisonElement;
3293 } else {
3294 if (maskIdx < v1Size)
3295 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3296 else
3297 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3298 }
3299
3300 results.push_back(indexedElm);
3301 }
3302
3303 return DenseElementsAttr::get(getResultVectorType(), results);
3304}
3305
3306namespace {
3307
3308// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3309// to a broadcast.
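// Illustrative example (types are hypothetical):
//   %0 = vector.shuffle %a, %b [1] : vector<f32>, vector<f32>
// is rewritten to:
//   %0 = vector.broadcast %b : vector<f32> to vector<1xf32>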
3310struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3311 using Base::Base;
3312
3313 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3314 PatternRewriter &rewriter) const override {
3315 VectorType v1VectorType = shuffleOp.getV1VectorType();
3316 ArrayRef<int64_t> mask = shuffleOp.getMask();
3317 if (v1VectorType.getRank() > 0)
3318 return failure();
3319 if (mask.size() != 1)
3320 return failure();
3321 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3322 if (mask[0] == 0)
3323 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3324 shuffleOp.getV1());
3325 else
3326 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3327 shuffleOp.getV2());
3328 return success();
3329 }
3330};
3331
3332/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3333/// vector.broadcast with a scalar operand, return the scalar value that is
3334/// splatted. Otherwise return null.
3335///
3336/// Example:
3337///
3338/// scalar_source --> vector.broadcast --> value - return scalar_source
3339static Value getScalarSplatSource(Value value) {
3340 // Block argument:
3341 Operation *defOp = value.getDefiningOp();
3342 if (!defOp)
3343 return {};
3344
3345 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3346
3347 // Not broadcast (and not splat):
3348 if (!broadcast)
3349 return {};
3350
3351 // Broadcast of a vector:
3352 if (isa<VectorType>(broadcast.getSourceType()))
3353 return {};
3354
3355 // Broadcast of a scalar:
3356 return broadcast.getSource();
3357}
3358
3359/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
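/// Illustrative example (names and types are hypothetical):
///   %s = vector.broadcast %v : f32 to vector<4xf32>
///   %0 = vector.shuffle %s, %s [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
/// is rewritten to:
///   %0 = vector.broadcast %v : f32 to vector<4xf32>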
3360class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3361public:
3362 using Base::Base;
3363
3364 LogicalResult matchAndRewrite(ShuffleOp op,
3365 PatternRewriter &rewriter) const override {
3366 Value splat = getScalarSplatSource(op.getV1());
3367 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3368 return failure();
3369
3370 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3371 return success();
3372 }
3373};
3374
3375/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3376/// vector.interleave.
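/// Illustrative example (a sketch; types are hypothetical):
///   %0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///          : vector<4xf32>, vector<4xf32>
/// is rewritten to (roughly):
///   %0 = vector.interleave %a, %b : vector<4xf32> -> vector<8xf32>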
3377class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3378public:
3379 using Base::Base;
3380
3381 LogicalResult matchAndRewrite(ShuffleOp op,
3382 PatternRewriter &rewriter) const override {
3383 VectorType resultType = op.getResultVectorType();
3384 if (resultType.isScalable())
3385 return rewriter.notifyMatchFailure(
3386 op, "ShuffleOp can't represent a scalable interleave");
3387
3388 if (resultType.getRank() != 1)
3389 return rewriter.notifyMatchFailure(
3390 op, "ShuffleOp can't represent an n-D interleave");
3391
3392 VectorType sourceType = op.getV1VectorType();
3393 if (sourceType != op.getV2VectorType() ||
3394 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3395 return rewriter.notifyMatchFailure(
3396 op, "ShuffleOp types don't match an interleave");
3397 }
3398
3399 ArrayRef<int64_t> shuffleMask = op.getMask();
3400 int64_t resultVectorSize = resultType.getNumElements();
3401 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3402 int64_t maskValueA = shuffleMask[i * 2];
3403 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3404 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3405 return rewriter.notifyMatchFailure(op,
3406 "ShuffleOp mask not interleaving");
3407 }
3408
3409 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3410 return success();
3411 }
3412};
3413
3414} // namespace
3415
3416void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3417 MLIRContext *context) {
3418 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3419 context);
3420}
3421
3422//===----------------------------------------------------------------------===//
3423// InsertOp
3424//===----------------------------------------------------------------------===//
3425
3426void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3427 SetIntRangeFn setResultRanges) {
3428 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3429}
3430
3431void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3432 Value source, Value dest) {
3433 auto vectorTy = cast<VectorType>(dest.getType());
3434 build(builder, result, source, dest,
3435 SmallVector<int64_t>(vectorTy.getRank(), 0));
3436}
3437
3438void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3439 Value source, Value dest, int64_t position) {
3440 build(builder, result, source, dest, ArrayRef<int64_t>{position});
3441}
3442
3443void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3444 Value source, Value dest, OpFoldResult position) {
3445 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3446}
3447
3448void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3449 Value source, Value dest,
3450 ArrayRef<int64_t> position) {
3451 SmallVector<OpFoldResult> posVals;
3452 posVals.reserve(position.size());
3453 llvm::transform(position, std::back_inserter(posVals),
3454 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3455 build(builder, result, source, dest, posVals);
3456}
3457
3458void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3459 Value source, Value dest,
3460 ArrayRef<OpFoldResult> position) {
3461 SmallVector<int64_t> staticPos;
3462 SmallVector<Value> dynamicPos;
3463 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3464 build(builder, result, source, dest, dynamicPos,
3465 builder.getDenseI64ArrayAttr(staticPos));
3466}
3467
3468LogicalResult InsertOp::verify() {
3469 if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3470 if (srcTy.getRank() == 0)
3471 return emitError(
3472 "expected a scalar instead of a 0-d vector as the source operand");
3473
3474 SmallVector<OpFoldResult> position = getMixedPosition();
3475 auto destVectorType = getDestVectorType();
3476 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3477 return emitOpError(
3478 "expected position attribute of rank no greater than dest vector rank");
3479 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3480 if (srcVectorType &&
3481 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3482 static_cast<unsigned>(destVectorType.getRank())))
3483 return emitOpError("expected position attribute rank + source rank to "
3484 "match dest vector rank");
3485 if (!srcVectorType &&
3486 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3487 return emitOpError(
3488 "expected position attribute rank to match the dest vector rank");
3489 for (auto [idx, pos] : llvm::enumerate(position)) {
3490 if (auto attr = dyn_cast<Attribute>(pos)) {
3491 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3492 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3493 destVectorType.getDimSize(idx))) {
3494 return emitOpError("expected position attribute #")
3495 << (idx + 1)
3496 << " to be a non-negative integer smaller than the "
3497 "corresponding "
3498 "dest vector dimension";
3499 }
3500 }
3501 }
3502 return success();
3503}
3504
3505// Calculate the linearized position of the contiguous chunk of elements to
3506// insert, based on the shape of the value to insert and the positions to
3507// insert at.
3508static int64_t calculateInsertPosition(VectorType destTy,
3509 ArrayRef<int64_t> positions) {
3510 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3511 assert(positions.size() <= completePositions.size() &&
3512 "positions size must be less than or equal to destTy rank");
3513 copy(positions, completePositions.begin());
3514 return linearize(completePositions, computeStrides(destTy.getShape()));
3515}
3516
3517namespace {
3518
3519// If insertOp is only inserting unit dimensions it can be transformed to a
3520// broadcast.
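// Illustrative example (types are hypothetical):
//   %0 = vector.insert %src, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// is rewritten to:
//   %0 = vector.broadcast %src : vector<4xf32> to vector<1x1x4xf32>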
3521class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3522public:
3523 using Base::Base;
3524
3525 LogicalResult matchAndRewrite(InsertOp insertOp,
3526 PatternRewriter &rewriter) const override {
3527 auto srcVecType =
3528 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3529 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3530 srcVecType.getNumElements())
3531 return failure();
3532 rewriter.replaceOpWithNewOp<BroadcastOp>(
3533 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3534 return success();
3535 }
3536};
3537
3538/// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
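/// Illustrative example (names and types are hypothetical):
///   %d = vector.broadcast %v : f32 to vector<2x2xf32>
///   %s = vector.broadcast %v : f32 to vector<2xf32>
///   %0 = vector.insert %s, %d [0] : vector<2xf32> into vector<2x2xf32>
/// is rewritten to:
///   %0 = vector.broadcast %v : f32 to vector<2x2xf32>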
3539class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3540public:
3541 using Base::Base;
3542
3543 LogicalResult matchAndRewrite(InsertOp op,
3544 PatternRewriter &rewriter) const override {
3545
3546 Value splat = getScalarSplatSource(op.getValueToStore());
3547 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3548 return failure();
3549
3550 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3551 return success();
3552 }
3553};
3554
3555/// Pattern to optimize a chain of insertions.
3556///
3557/// This pattern identifies chains of vector.insert operations that:
3558/// 1. Only insert values at static positions.
3559/// 2. Completely initialize all elements in the resulting vector.
3560/// 3. All intermediate insert operations have only one use.
3561///
3562/// When these conditions are met, the entire chain can be replaced with a
3563/// single vector.from_elements operation.
3564///
3565/// To keep this pattern simple, and avoid spending too much time on matching
3566/// fragmented insert chains, this pattern only considers the last insert op in
3567/// the chain.
3568///
3569/// Example transformation:
3570/// %poison = ub.poison : vector<2xi32>
3571/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3572/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3573/// ->
3574/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3575class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3576public:
3577 using Base::Base;
3578 LogicalResult matchAndRewrite(InsertOp op,
3579 PatternRewriter &rewriter) const override {
3580
3581 VectorType destTy = op.getDestVectorType();
3582 if (destTy.isScalable())
3583 return failure();
3584 // Ensure this is the trailing vector.insert op in a chain of inserts.
3585 for (Operation *user : op.getResult().getUsers())
3586 if (auto insertOp = dyn_cast<InsertOp>(user))
3587 if (insertOp.getDest() == op.getResult())
3588 return failure();
3589
3590 InsertOp currentOp = op;
3591 SmallVector<InsertOp> chainInsertOps;
3592 while (currentOp) {
3593 // Check cond 1: Dynamic position is not supported.
3594 if (currentOp.hasDynamicPosition())
3595 return failure();
3596
3597 chainInsertOps.push_back(currentOp);
3598 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3599 // Check cond 3: Intermediate inserts have only one use to avoid an
3600 // explosion of vectors.
3601 if (currentOp && !currentOp->hasOneUse())
3602 return failure();
3603 }
3604
3605 int64_t vectorSize = destTy.getNumElements();
3606 int64_t initializedCount = 0;
3607 SmallVector<bool> initializedDestIdxs(vectorSize, false);
3608 SmallVector<int64_t> pendingInsertPos;
3609 SmallVector<int64_t> pendingInsertSize;
3610 SmallVector<Value> pendingInsertValues;
3611
3612 for (auto insertOp : chainInsertOps) {
3613 // This pattern can do nothing with poison index.
3614 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3615 return failure();
3616
3617 // Calculate the linearized position for inserting elements.
3618 int64_t insertBeginPosition =
3619 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3620
3621 // The valueToStore operand may be a vector or a scalar. Need to handle
3622 // both cases.
3623 int64_t insertSize = 1;
3624 if (auto srcVectorType =
3625 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3626 insertSize = srcVectorType.getNumElements();
3627
3628 assert(insertBeginPosition + insertSize <= vectorSize &&
3629 "insert would overflow the vector");
3630
3631 for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3632 insertBeginPosition + insertSize)) {
3633 if (initializedDestIdxs[index])
3634 continue;
3635 initializedDestIdxs[index] = true;
3636 ++initializedCount;
3637 }
3638
3639 // Defer op creation until we are sure the pattern succeeds.
3641 pendingInsertPos.push_back(insertBeginPosition);
3642 pendingInsertSize.push_back(insertSize);
3643 pendingInsertValues.push_back(insertOp.getValueToStore());
3644
3645 if (initializedCount == vectorSize)
3646 break;
3647 }
3648
3649 // Check cond 2: all positions must be initialized.
3650 if (initializedCount != vectorSize)
3651 return failure();
3652
3653 SmallVector<Value> elements(vectorSize);
3654 for (auto [insertBeginPosition, insertSize, valueToStore] :
3655 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3656 pendingInsertValues))) {
3657 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3658
3659 if (!srcVectorType) {
3660 elements[insertBeginPosition] = valueToStore;
3661 continue;
3662 }
3663
3664 SmallVector<Type> elementToInsertTypes(insertSize,
3665 srcVectorType.getElementType());
3666 // Get all elements from the vector in row-major order.
3667 auto elementsToInsert = vector::ToElementsOp::create(
3668 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3669 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3670 elements[insertBeginPosition + linearIdx] =
3671 elementsToInsert.getResult(linearIdx);
3672 }
3673 }
3674
3675 rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3676 return success();
3677 }
3678};
3679
3680} // namespace
3681
3682static Attribute
3683foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3684 Attribute dstAttr,
3685 int64_t maxVectorSizeFoldThreshold) {
3686 if (insertOp.hasDynamicPosition())
3687 return {};
3688
3689 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3690 if (!denseDst)
3691 return {};
3692
3693 if (!srcAttr) {
3694 return {};
3695 }
3696
3697 VectorType destTy = insertOp.getDestVectorType();
3698 if (destTy.isScalable())
3699 return {};
3700
3701 // Make sure we do not create too many large constants.
3702 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3703 !insertOp->hasOneUse())
3704 return {};
3705
3706 // Calculate the linearized position for inserting elements.
3707 int64_t insertBeginPosition =
3708 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3709 SmallVector<Attribute> insertedValues;
3710 Type destEltType = destTy.getElementType();
3711
3712 /// Converts attribute to the expected type if there's
3713 /// a mismatch.
3714 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3715 for (auto value : denseSource.getValues<Attribute>())
3716 insertedValues.push_back(convertNumericAttr(value, destEltType));
3717 } else {
3718 insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3719 }
3720
3721 auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3722 copy(insertedValues, allValues.begin() + insertBeginPosition);
3723 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3724
3725 return newAttr;
3726}
3727
3728/// Folder to replace the `dest` operand of the insert op with the root dest of
3729/// the insert op use chain.
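/// Illustrative example (names and types are hypothetical):
///   %0 = vector.insert %x, %dst [2] : f32 into vector<4xf32>
///   %1 = vector.insert %y, %0 [2] : f32 into vector<4xf32>
/// The second insert overwrites the first at the same position, so %1 can use
/// %dst directly:
///   %1 = vector.insert %y, %dst [2] : f32 into vector<4xf32>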
3730static Value foldInsertUseChain(InsertOp insertOp) {
3731 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3732 if (!destInsert)
3733 return {};
3734
3735 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3736 return {};
3737
3738 insertOp.setOperand(1, destInsert.getDest());
3739 return insertOp.getResult();
3740}
3741
3742void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3743 MLIRContext *context) {
3744 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3745 InsertChainFullyInitialized>(context);
3746}
3747
3748OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3749 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3750 // unless the source vector constant has a single use.
3751 constexpr int64_t vectorSizeFoldThreshold = 256;
3752 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3753 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3754 // (type mismatch).
3755 if (getNumIndices() == 0 && getValueToStoreType() == getType())
3756 return getValueToStore();
3757 // Fold `arith.constant` indices into the `vector.insert` operation.
3758 // Do not stop here as this fold may enable subsequent folds that require
3759 // constant indices.
3760 SmallVector<Value> operands = {getValueToStore(), getDest()};
3761 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3762
3763 if (auto res = foldInsertUseChain(*this))
3764 return res;
3765 if (auto res = foldPoisonIndexInsertExtractOp(
3766 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3767 return res;
3768 if (auto res = foldDenseElementsAttrDestInsertOp(
3769 *this, adaptor.getValueToStore(), adaptor.getDest(),
3770 vectorSizeFoldThreshold)) {
3771 return res;
3772 }
3773
3774 return inplaceFolded;
3775}
3776
3777//===----------------------------------------------------------------------===//
3778// InsertStridedSliceOp
3779//===----------------------------------------------------------------------===//
3780
3781void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3782 Value source, Value dest,
3783 ArrayRef<int64_t> offsets,
3784 ArrayRef<int64_t> strides) {
3785 result.addOperands({source, dest});
3786 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3787 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3788 result.addTypes(dest.getType());
3789 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3790 offsetsAttr);
3791 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3792 stridesAttr);
3793}
3794
3795// TODO: Should be moved to Tablegen ConfinedAttr attributes.
3796template <typename OpType>
3797static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3798 ArrayAttr arrayAttr,
3799 ArrayRef<int64_t> shape,
3800 StringRef attrName) {
3801 if (arrayAttr.size() > shape.size())
3802 return op.emitOpError("expected ")
3803 << attrName << " attribute of rank no greater than vector rank";
3804 return success();
3805}
3806
3807// Returns success if all integers in `arrayAttr` are within the admissible
3808// range: [min, max) when `halfOpen` is true, [min, max] otherwise.
3810template <typename OpType>
3811static LogicalResult
3812isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3813 int64_t max, StringRef attrName,
3814 bool halfOpen = true) {
3815 for (auto attr : arrayAttr) {
3816 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3817 auto upper = max;
3818 if (!halfOpen)
3819 upper += 1;
3820 if (val < min || val >= upper)
3821 return op.emitOpError("expected ") << attrName << " to be confined to ["
3822 << min << ", " << upper << ")";
3823 }
3824 return success();
3825}
3826
3827// Returns success if, for each dimension i, the integer in `arrayAttr` is in
3828// the admissible interval derived from `shape[i]`: [min, shape[i]) when
3829// `halfOpen` is true, [min, shape[i]] otherwise.
3830template <typename OpType>
3831static LogicalResult
3832isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3833 ArrayRef<int64_t> shape, StringRef attrName,
3834 bool halfOpen = true, int64_t min = 0) {
3835 for (auto [index, attrDimPair] :
3836 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3837 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3838 int64_t max = std::get<1>(attrDimPair);
3839 if (!halfOpen)
3840 max += 1;
3841 if (val < min || val >= max)
3842 return op.emitOpError("expected ")
3843 << attrName << " dimension " << index << " to be confined to ["
3844 << min << ", " << max << ")";
3845 }
3846 return success();
3847}
3848
3849// Returns success if, for all indices i = 0..shape.size()-1, the value
3850//   val = `arrayAttr1[i]` + `arrayAttr2[i]`
3851// is in the admissible interval, where max = shape[i]. If `halfOpen` is true
3852// the admissible interval is [min, max); otherwise it is [min, max].
3854template <typename OpType>
3855static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3856 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3857 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3858 bool halfOpen = true, int64_t min = 1) {
3859 assert(arrayAttr1.size() <= shape.size());
3860 assert(arrayAttr2.size() <= shape.size());
3861 for (auto [index, it] :
3862 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3863 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3864 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3865 int64_t max = std::get<2>(it);
3866 if (!halfOpen)
3867 max += 1;
3868 if (val1 + val2 < 0 || val1 + val2 >= max)
3869 return op.emitOpError("expected sum(")
3870 << attrName1 << ", " << attrName2 << ") dimension " << index
3871 << " to be confined to [" << min << ", " << max << ")";
3872 }
3873 return success();
3874}
3875
3875
3876static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3877 MLIRContext *context) {
3878 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3879 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3880 });
3881 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3882}
3883
3884LogicalResult InsertStridedSliceOp::verify() {
3885 auto sourceVectorType = getSourceVectorType();
3886 auto destVectorType = getDestVectorType();
3887 auto offsets = getOffsetsAttr();
3888 auto strides = getStridesAttr();
3889 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3890 return emitOpError(
3891 "expected offsets of same size as destination vector rank");
3892 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3893 return emitOpError("expected strides of same size as source vector rank");
3894 if (sourceVectorType.getRank() > destVectorType.getRank())
3895 return emitOpError(
3896 "expected source rank to be no greater than destination rank");
3897
3898 auto sourceShape = sourceVectorType.getShape();
3899 auto destShape = destVectorType.getShape();
3900 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3901 destShape.size() - sourceShape.size(), 0);
3902 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3903 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3904 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3905 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3906 offName)) ||
3907 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3908 /*max=*/1, stridesName,
3909 /*halfOpen=*/false)) ||
3910 failed(isSumOfIntegerArrayAttrConfinedToShape(
3911 *this, offsets,
3912 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3913 offName, "source vector shape",
3914 /*halfOpen=*/false, /*min=*/1)))
3915 return failure();
3916
3917 unsigned rankDiff = destShape.size() - sourceShape.size();
3918 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3919 if (sourceVectorType.getScalableDims()[idx] !=
3920 destVectorType.getScalableDims()[idx + rankDiff]) {
3921 return emitOpError("mismatching scalable flags (at source vector idx=")
3922 << idx << ")";
3923 }
3924 if (sourceVectorType.getScalableDims()[idx]) {
3925 auto sourceSize = sourceShape[idx];
3926 auto destSize = destShape[idx + rankDiff];
3927 if (sourceSize != destSize) {
3928 return emitOpError("expected size at idx=")
3929 << idx
3930 << (" to match the corresponding base size from the input "
3931 "vector (")
3932 << sourceSize << (" vs ") << destSize << (")");
3933 }
3934 }
3935 }
3936
3937 return success();
3938}
3939
3940namespace {
3941/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
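/// Illustrative example (names and types are hypothetical):
///   %d = vector.broadcast %v : f32 to vector<8xf32>
///   %s = vector.broadcast %v : f32 to vector<4xf32>
///   %0 = vector.insert_strided_slice %s, %d
///          {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
/// is rewritten to %d, since both operands splat the same scalar.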
3942class FoldInsertStridedSliceSplat final
3943 : public OpRewritePattern<InsertStridedSliceOp> {
3944public:
3945 using Base::Base;
3946
3947 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3948 PatternRewriter &rewriter) const override {
3949
3950 auto dst = insertStridedSliceOp.getDest();
3951 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3952 if (!splat || getScalarSplatSource(dst) != splat)
3953 return failure();
3954
3955 rewriter.replaceOp(insertStridedSliceOp, dst);
3956 return success();
3957 }
3958};
3959
3960/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3961/// to dst.
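/// Illustrative example (types are hypothetical):
///   %e = vector.extract_strided_slice %dst
///          {offsets = [2], sizes = [4], strides = [1]}
///          : vector<8xf32> to vector<4xf32>
///   %0 = vector.insert_strided_slice %e, %dst
///          {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
/// folds to %dst, because the inserted slice was extracted from the same
/// destination at the same offsets and strides.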
3962class FoldInsertStridedSliceOfExtract final
3963 : public OpRewritePattern<InsertStridedSliceOp> {
3964public:
3965 using Base::Base;
3966
3967 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3968 PatternRewriter &rewriter) const override {
3969 auto extractStridedSliceOp =
3970 insertStridedSliceOp.getValueToStore()
3971 .getDefiningOp<vector::ExtractStridedSliceOp>();
3972
3973 if (!extractStridedSliceOp)
3974 return failure();
3975
3976 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3977 return failure();
3978
3979 // Check if have the same strides and offsets.
3980 if (extractStridedSliceOp.getStrides() !=
3981 insertStridedSliceOp.getStrides() ||
3982 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3983 return failure();
3984
3985 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3986 return success();
3987 }
3988};
3989
3990// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3991// ConstantOp.
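// Illustrative example (values and types are hypothetical):
//   %cst = arith.constant dense<0> : vector<4xi32>
//   %one = arith.constant dense<1> : vector<2xi32>
//   %0 = vector.insert_strided_slice %one, %cst
//          {offsets = [1], strides = [1]} : vector<2xi32> into vector<4xi32>
// is rewritten to:
//   %0 = arith.constant dense<[0, 1, 1, 0]> : vector<4xi32>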
3992class InsertStridedSliceConstantFolder final
3993 : public OpRewritePattern<InsertStridedSliceOp> {
3994public:
3995 using Base::Base;
3996
3997 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3998 // unless the source vector constant has a single use.
3999 static constexpr int64_t vectorSizeFoldThreshold = 256;
4000
4001 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4002 PatternRewriter &rewriter) const override {
4003 // Return if the 'InsertStridedSliceOp' dest operand is not defined by a
4004 // compatible vector ConstantOp.
4005 TypedValue<VectorType> destVector = op.getDest();
4006 Attribute vectorDestCst;
4007 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
4008 return failure();
4009
4010 VectorType destTy = destVector.getType();
4011 if (destTy.isScalable())
4012 return failure();
4013
4014 // Make sure we do not create too many large constants.
4015 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4016 !destVector.hasOneUse())
4017 return failure();
4018
4019 TypedValue<VectorType> sourceValue = op.getValueToStore();
4020 Attribute sourceCst;
4021 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
4022 return failure();
4023
4024 // TODO: Support poison.
4025 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
4026 return failure();
4027
4028 // TODO: Handle non-unit strides when they become available.
4029 if (op.hasNonUnitStrides())
4030 return failure();
4031
4032 VectorType sliceVecTy = sourceValue.getType();
4033 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4034 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4035 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
4036 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
4037
4038 // Calculate the destination element indices by enumerating all slice
4039 // positions within the destination and linearizing them. The enumeration
4040 // order is lexicographic which yields a sequence of monotonically
4041 // increasing linearized position indices.
4042 // Because the destination may have higher dimensionality than the slice,
4043 // we keep track of two overlapping sets of positions and offsets.
4044 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4045 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4046 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4047 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4048 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4049 MutableArrayRef<int64_t> currSlicePosition(
4050 currDestPosition.begin() + rankDifference, currDestPosition.end());
4051 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4052 offsets.end());
4053 do {
4054 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
4055 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
4056 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4057 "Invalid slice element");
4058 newValues[linearizedPosition] = *sliceValuesIt;
4059 ++sliceValuesIt;
4060 } while (succeeded(
4061 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
4062
4063 auto newAttr = DenseElementsAttr::get(destTy, newValues);
4064 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
4065 return success();
4066 }
4067};
4068
4069} // namespace
4070
4071void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4072 RewritePatternSet &results, MLIRContext *context) {
4073 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4074 InsertStridedSliceConstantFolder>(context);
4075}
4076
4077OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4078 if (getSourceVectorType() == getDestVectorType())
4079 return getValueToStore();
4080 return {};
4081}
4082
4083//===----------------------------------------------------------------------===//
4084// OuterProductOp
4085//===----------------------------------------------------------------------===//
4086
4087/// Build an op without mask, use the type of `acc` as the return type.
4088void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4089 Value lhs, Value rhs, Value acc) {
4090 result.addOperands({lhs, rhs, acc});
4091 result.addTypes(acc.getType());
4092}
4093
4094void OuterProductOp::print(OpAsmPrinter &p) {
4095 p << " " << getLhs() << ", " << getRhs();
4096 if (getAcc()) {
4097 p << ", " << getAcc();
4098 p.printOptionalAttrDict((*this)->getAttrs());
4099 }
4100 p << " : " << getLhs().getType() << ", " << getRhs().getType();
4101}
4102
4103ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4104 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4105 Type tLHS, tRHS;
4106 if (parser.parseOperandList(operandsInfo) ||
4107 parser.parseOptionalAttrDict(result.attributes) ||
4108 parser.parseColonType(tLHS) || parser.parseComma() ||
4109 parser.parseType(tRHS))
4110 return failure();
4111 if (operandsInfo.size() < 2)
4112 return parser.emitError(parser.getNameLoc(),
4113 "expected at least 2 operands");
4114 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4115 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4116 if (!vLHS)
4117 return parser.emitError(parser.getNameLoc(),
4118 "expected vector type for operand #1");
4119
4120 VectorType resType;
4121 if (vRHS) {
4122 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4123 vRHS.getScalableDims()[0]};
4124 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4125 vLHS.getElementType(), scalableDimsRes);
4126 } else {
4127 // Scalar RHS operand
4128 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4129 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4130 scalableDimsRes);
4131 }
4132
4133 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4134 result.attributes.append(
4135 OuterProductOp::getKindAttrName(result.name),
4136 CombiningKindAttr::get(result.getContext(),
4137 OuterProductOp::getDefaultKind()));
4138 }
4139
4140 return failure(
4141 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4142 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4143 (operandsInfo.size() > 2 &&
4144 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4145 parser.addTypeToList(resType, result.types));
4146}
4147
4148LogicalResult OuterProductOp::verify() {
4149 Type tRHS = getOperandTypeRHS();
4150 VectorType vLHS = getOperandVectorTypeLHS(),
4151 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4152 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4153
4154 if (vLHS.getRank() != 1)
4155 return emitOpError("expected 1-d vector for operand #1");
4156
4157 if (vRHS) {
4158 // Proper OUTER operation.
4159 if (vRHS.getRank() != 1)
4160 return emitOpError("expected 1-d vector for operand #2");
4161 if (vRES.getRank() != 2)
4162 return emitOpError("expected 2-d vector result");
4163 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4164 return emitOpError("expected #1 operand dim to match result dim #1");
4165 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4166 return emitOpError("expected #2 operand dim to match result dim #2");
4167 if (vLHS.isScalable() && !vRHS.isScalable()) {
4168 // This restriction reflects what's currently supported in terms of
4169 // scalable vectors. However, we could relax this if there's a use case.
4170 return emitOpError(
4171 "expected either both or only #2 operand dim to be scalable");
4172 }
4173 } else {
4174 // An AXPY operation.
4175 if (vRES.getRank() != 1)
4176 return emitOpError("expected 1-d vector result");
4177 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4178 return emitOpError("expected #1 operand dim to match result dim #1");
4179 }
4180
4181 if (vACC && vACC != vRES)
4182 return emitOpError("expected operand #3 of same type as result type");
4183
4184 if (!getKindAttr()) {
4185 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
4186 "'vector.kind<add>')");
4187 }
4188
4189 // Verify supported combining kind.
4190 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4191 return emitOpError("unsupported outerproduct type");
4192
4193 return success();
4194}
4195
4196// MaskableOpInterface methods.
4197
4198/// Returns the mask type expected by this operation. Mostly used for
4199/// verification purposes. It requires the operation to be vectorized."
4200Type OuterProductOp::getExpectedMaskType() {
4201 auto vecType = this->getResultVectorType();
4202 return VectorType::get(vecType.getShape(),
4203 IntegerType::get(vecType.getContext(), /*width=*/1),
4204 vecType.getScalableDims());
4205}
4206
4207//===----------------------------------------------------------------------===//
4208// ExtractStridedSliceOp
4209//===----------------------------------------------------------------------===//
4210
4211// Inference works as follows:
4212// 1. Add 'sizes' from prefix of dims in 'offsets'.
4213// 2. Add sizes from 'vectorType' for remaining dims.
4214// Scalable flags are inherited from 'vectorType'.
4215static Type inferStridedSliceOpResultType(VectorType vectorType,
4216 ArrayAttr offsets, ArrayAttr sizes,
4217 ArrayAttr strides) {
4218 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4219 SmallVector<int64_t, 4> shape;
4220 shape.reserve(vectorType.getRank());
4221 unsigned idx = 0;
4222 for (unsigned e = offsets.size(); idx < e; ++idx)
4223 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4224 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4225 shape.push_back(vectorType.getShape()[idx]);
4226
4227 return VectorType::get(shape, vectorType.getElementType(),
4228 vectorType.getScalableDims());
4229}
4230
4231void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4232 Value source, ArrayRef<int64_t> offsets,
4233 ArrayRef<int64_t> sizes,
4234 ArrayRef<int64_t> strides) {
4235 result.addOperands(source);
4236 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4237 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4238 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4239 result.addTypes(
4240 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4241 offsetsAttr, sizesAttr, stridesAttr));
4242 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4243 offsetsAttr);
4244 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4245 sizesAttr);
4246 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4247 stridesAttr);
4248}
4249
4250LogicalResult ExtractStridedSliceOp::verify() {
4251 auto type = getSourceVectorType();
4252 auto offsets = getOffsetsAttr();
4253 auto sizes = getSizesAttr();
4254 auto strides = getStridesAttr();
4255 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4256 return emitOpError(
4257 "expected offsets, sizes and strides attributes of same size");
4258
4259 auto shape = type.getShape();
4260 auto offName = getOffsetsAttrName();
4261 auto sizesName = getSizesAttrName();
4262 auto stridesName = getStridesAttrName();
4263 if (failed(
4264 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4265 failed(
4266 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4267 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4268 stridesName)) ||
4269 failed(
4270 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4271 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4272 /*halfOpen=*/false,
4273 /*min=*/1)) ||
4274 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4275 /*max=*/1, stridesName,
4276 /*halfOpen=*/false)) ||
4277 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4278 shape, offName, sizesName,
4279 /*halfOpen=*/false)))
4280 return failure();
4281
4282 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4283 offsets, sizes, strides);
4284 if (getResult().getType() != resultType)
4285 return emitOpError("expected result type to be ") << resultType;
4286
4287 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4288 if (type.getScalableDims()[idx]) {
4289 auto inputDim = type.getShape()[idx];
4290 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4291 if (inputDim != inputSize)
4292 return emitOpError("expected size at idx=")
4293 << idx
4294 << (" to match the corresponding base size from the input "
4295 "vector (")
4296 << inputSize << (" vs ") << inputDim << (")");
4297 }
4298 }
4299
4300 return success();
4301}
4302
4303// When the source of ExtractStridedSliceOp comes from a chain of
4304// InsertStridedSliceOps, try to use the source of one of those inserts if we
4305// can detect that the extracted vector is a subset of a vector it inserted.
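// Illustrative example (types are hypothetical):
//   %i = vector.insert_strided_slice %src, %dst
//          {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32>
//   %e = vector.extract_strided_slice %i
//          {offsets = [5], sizes = [2], strides = [1]}
//          : vector<8xf32> to vector<2xf32>
// The extracted chunk [5, 7) lies inside the inserted chunk [4, 8), so the
// extract can read from %src with rebased offsets:
//   %e = vector.extract_strided_slice %src
//          {offsets = [1], sizes = [2], strides = [1]}
//          : vector<4xf32> to vector<2xf32>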
4306static LogicalResult
4307foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4308 // Helper to extract integer out of ArrayAttr.
4309 auto getElement = [](ArrayAttr array, int idx) {
4310 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4311 };
4312 ArrayAttr extractOffsets = op.getOffsets();
4313 ArrayAttr extractStrides = op.getStrides();
4314 ArrayAttr extractSizes = op.getSizes();
4315 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4316 while (insertOp) {
4317 if (op.getSourceVectorType().getRank() !=
4318 insertOp.getSourceVectorType().getRank())
4319 return failure();
4320 ArrayAttr insertOffsets = insertOp.getOffsets();
4321 ArrayAttr insertStrides = insertOp.getStrides();
4322 // If the rank of extract is greater than the rank of insert, we are likely
4323 // extracting a partial chunk of the vector inserted.
4324 if (extractOffsets.size() > insertOffsets.size())
4325 return failure();
4326 bool partialOverlap = false;
4327 bool disjoint = false;
4328 SmallVector<int64_t, 4> offsetDiffs;
4329 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4330 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4331 return failure();
4332 int64_t start = getElement(insertOffsets, dim);
4333 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4334 int64_t offset = getElement(extractOffsets, dim);
4335 int64_t size = getElement(extractSizes, dim);
4336 // Check if the start of the extract offset is in the interval inserted.
4337 if (start <= offset && offset < end) {
4338 // If the extract interval overlaps but is not fully included we may
4339 // have a partial overlap that will prevent any folding.
4340 if (offset + size > end)
4341 partialOverlap = true;
4342 offsetDiffs.push_back(offset - start);
4343 continue;
4344 }
4345 disjoint = true;
4346 break;
4347 }
4348 // The extract element chunk is a subset of the insert element.
4349 if (!disjoint && !partialOverlap) {
4350 op.setOperand(insertOp.getValueToStore());
4351 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4352 OpBuilder b(op.getContext());
4353 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4354 return success();
4355 }
4356 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4357 // in the insert chain.
4358 if (disjoint)
4359 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4360 else {
4361 // The extracted vector partially overlap the inserted vector, we cannot
4362 // fold.
4363 return failure();
4364 }
4365 }
4366 return failure();
4367}
4368
4369// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4370static OpFoldResult
4372 Attribute foldInput) {
4373
4374 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4375 if (!dense)
4376 return {};
4377
4378 // TODO: Handle non-unit strides when they become available.
4379 if (op.hasNonUnitStrides())
4380 return {};
4381
4382 VectorType sourceVecTy = op.getSourceVectorType();
4383 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4384 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4385
4386 VectorType sliceVecTy = op.getType();
4387 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4388 int64_t rank = sliceVecTy.getRank();
4389
4390 // Expand offsets and sizes to match the vector rank.
4391 SmallVector<int64_t, 4> offsets(rank, 0);
4392 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4393
4394 SmallVector<int64_t, 4> sizes(sourceShape);
4395 copy(getI64SubArray(op.getSizes()), sizes.begin());
4396
4397 // Calculate the slice elements by enumerating all slice positions and
4398 // linearizing them. The enumeration order is lexicographic which yields a
4399 // sequence of monotonically increasing linearized position indices.
4400 const auto denseValuesBegin = dense.value_begin<Attribute>();
4401 SmallVector<Attribute> sliceValues;
4402 sliceValues.reserve(sliceVecTy.getNumElements());
4403 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4404 do {
4405 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4406 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4407 "Invalid index");
4408 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4409 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4410
4411 assert(static_cast<int64_t>(sliceValues.size()) ==
4412 sliceVecTy.getNumElements() &&
4413 "Invalid number of slice elements");
4414 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4415}
4416
4417OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4418 if (getSourceVectorType() == getResult().getType())
4419 return getSource();
4420 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4421 return getResult();
4422
4423 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4424 if (auto splat =
4425 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4426 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4427
4428 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4429 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4430}
4431
4432void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4433 populateFromInt64AttrArray(getOffsets(), results);
4434}
4435
4436namespace {
4437
4438// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
4439//
4440// Example:
4441//
4442// %0 = vector.extract_strided_slice %arg0
4443// {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
4444// : vector<4x8x16xf32> to vector<3x4x16xf32>
4445// %1 = vector.extract_strided_slice %0
4446// {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
4447// : vector<3x4x16xf32> to vector<2x2x16xf32>
4448//
4449// to
4450//
4451// %1 = vector.extract_strided_slice %arg0
4452// {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
4453// : vector<4x8x16xf32> to vector<2x2x16xf32>
4454class StridedSliceFolder final
4455 : public OpRewritePattern<ExtractStridedSliceOp> {
4456public:
4457 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4458
4459 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4460 PatternRewriter &rewriter) const override {
4461 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4462 if (!firstOp)
4463 return failure();
4464
4465 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4466 return failure();
4467
4468 SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
4469 SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
4470 SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
4471 SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());
4472
4473 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4474 SmallVector<int64_t> combinedOffsets(newRank, 0);
4475 SmallVector<int64_t> combinedSizes(newRank);
4476 ArrayRef<int64_t> firstSourceShape =
4477 firstOp.getSourceVectorType().getShape();
4478 for (unsigned i = 0; i < newRank; ++i) {
4479 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4480 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4481 combinedOffsets[i] = off1 + off2;
4482
4483 if (i < secondSizes.size()) {
4484 combinedSizes[i] = secondSizes[i];
4485 } else if (i < firstSizes.size()) {
4486 combinedSizes[i] = firstSizes[i];
4487 } else {
4488 combinedSizes[i] = firstSourceShape[i];
4489 }
4490 }
4491
4492 SmallVector<int64_t> combinedStrides(newRank, 1);
4493 rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
4494 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4495 combinedStrides);
4496 return success();
4497 }
4498};
4499
4500// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4501// CreateMaskOp.
4502//
4503// Example:
4504//
4505// %mask = vector.create_mask %ub : vector<16xi1>
4506// %slice = vector.extract_strided_slice [%offset] [8] [1]
4507//
4508// to
4509//
4510// %new_ub = arith.subi %ub, %offset
4511// %mask = vector.create_mask %new_ub : vector<8xi1>
4512class StridedSliceCreateMaskFolder final
4513 : public OpRewritePattern<ExtractStridedSliceOp> {
4514 using Base::Base;
4515
4516public:
4517 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4518 PatternRewriter &rewriter) const override {
4519 Location loc = extractStridedSliceOp.getLoc();
4520 // Return if 'extractStridedSliceOp' operand is not defined by a
4521 // CreateMaskOp.
4522 auto createMaskOp =
4523 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4524 if (!createMaskOp)
4525 return failure();
4526 // Return if 'extractStridedSliceOp' has non-unit strides.
4527 if (extractStridedSliceOp.hasNonUnitStrides())
4528 return failure();
4529 // Gather constant mask dimension sizes.
4530 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4531 // Gather strided slice offsets and sizes.
4532 SmallVector<int64_t> sliceOffsets;
4533 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4534 sliceOffsets);
4535 SmallVector<int64_t> sliceSizes;
4536 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4537
4538 // Compute slice of vector mask region.
4539 SmallVector<Value> sliceMaskDimSizes;
4540 sliceMaskDimSizes.reserve(maskDimSizes.size());
4541 // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4542 // only iterate on the leading dim sizes. The tail accounts for the
4543 // remaining dim sizes.
4544 for (auto [maskDimSize, sliceOffset, sliceSize] :
4545 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4546 // No need to clamp on min/max values, because create_mask has clamping
4547 // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4548 // greater than the vector dim size.
4549 IntegerAttr offsetAttr =
4550 rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4551 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4552 Value sliceMaskDimSize =
4553 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4554 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4555 }
4556 // Add unchanged dimensions.
4557 llvm::append_range(
4558 sliceMaskDimSizes,
4559 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4560 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4561 // region.
4562 rewriter.replaceOpWithNewOp<CreateMaskOp>(
4563 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4564 sliceMaskDimSizes);
4565 return success();
4566 }
4567};
4568
4569// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4570// ConstantMaskOp.
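//
// Example (illustrative sketch written for this comment):
//
//   %mask = vector.constant_mask [4] : vector<8xi1>
//   %slice = vector.extract_strided_slice %mask
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xi1> to vector<4xi1>
//
// becomes
//
//   %slice = vector.constant_mask [2] : vector<4xi1>
//
// since only the first two elements of the sliced window fall inside the
// original masked region.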
4571class StridedSliceConstantMaskFolder final
4572 : public OpRewritePattern<ExtractStridedSliceOp> {
4573public:
4574 using Base::Base;
4575
4576 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4577 PatternRewriter &rewriter) const override {
4578 // Return if 'extractStridedSliceOp' operand is not defined by a
4579 // ConstantMaskOp.
4580 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4581 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4582 if (!constantMaskOp)
4583 return failure();
4584 // Return if 'extractStridedSliceOp' has non-unit strides.
4585 if (extractStridedSliceOp.hasNonUnitStrides())
4586 return failure();
4587 // Gather constant mask dimension sizes.
4588 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4589 // Gather strided slice offsets and sizes.
4590 SmallVector<int64_t> sliceOffsets;
4591 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4592 sliceOffsets);
4593 SmallVector<int64_t> sliceSizes;
4594 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4595
4596 // Compute slice of vector mask region.
4597 SmallVector<int64_t> sliceMaskDimSizes;
4598 sliceMaskDimSizes.reserve(maskDimSizes.size());
4599 for (auto [maskDimSize, sliceOffset, sliceSize] :
4600 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4601 int64_t sliceMaskDimSize = std::max(
4602 static_cast<int64_t>(0),
4603 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4604 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4605 }
4606 // Add unchanged dimensions.
4607 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4608 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4609 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4610 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4611 // region is a conjunction of mask dim intervals).
4612 if (llvm::is_contained(sliceMaskDimSizes, 0))
4613 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4614
4615 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4616 // region.
4617 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4618 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4619 sliceMaskDimSizes);
4620 return success();
4621 }
4622};
4623
4624// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4625 // BroadcastOp(ExtractStridedSliceOp).
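//
// Example (illustrative sketch written for this comment):
//
//   %b = vector.broadcast %src : vector<1x8xf32> to vector<4x8xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x4xf32>
//
// becomes
//
//   %e = vector.extract_strided_slice %src
//       {offsets = [0, 2], sizes = [1, 4], strides = [1, 1]}
//       : vector<1x8xf32> to vector<1x4xf32>
//   %s = vector.broadcast %e : vector<1x4xf32> to vector<2x4xf32>
//
// The broadcasted dim is reset to offset 0 / size 1 because there is no
// longer a broadcast in front of the slice.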
4626class StridedSliceBroadcast final
4627 : public OpRewritePattern<ExtractStridedSliceOp> {
4628public:
4629 using Base::Base;
4630
4631 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4632 PatternRewriter &rewriter) const override {
4633 auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4634 if (!broadcast)
4635 return failure();
4636 auto srcVecType =
4637 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4638 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4639 auto dstVecType = llvm::cast<VectorType>(op.getType());
4640 unsigned dstRank = dstVecType.getRank();
4641 unsigned rankDiff = dstRank - srcRank;
4642 // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4643 // (n -> m with n > m). If they are originally both broadcasted *and*
4644 // sliced, this can be simplified to just broadcasting.
4645 bool needsSlice = false;
4646 for (unsigned i = 0; i < srcRank; i++) {
4647 if (srcVecType.getDimSize(i) != 1 &&
4648 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4649 needsSlice = true;
4650 break;
4651 }
4652 }
4653 Value source = broadcast.getSource();
4654 if (needsSlice) {
4655 SmallVector<int64_t> offsets =
4656 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4657 SmallVector<int64_t> sizes =
4658 getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4659 for (unsigned i = 0; i < srcRank; i++) {
4660 if (srcVecType.getDimSize(i) == 1) {
4661 // In case this dimension was broadcasted *and* sliced, the offset
4662 // and size need to be updated now that there is no broadcast before
4663 // the slice.
4664 offsets[i] = 0;
4665 sizes[i] = 1;
4666 }
4667 }
4668 source = ExtractStridedSliceOp::create(
4669 rewriter, op->getLoc(), source, offsets, sizes,
4670 getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4671 }
4672 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4673 return success();
4674 }
4675};
4676
4677/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
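///
/// Example (illustrative sketch written for this comment): with
///   %s = vector.broadcast %f : f32 to vector<8xf32>
/// the op
///   vector.extract_strided_slice %s {offsets = [2], sizes = [4], strides = [1]}
///     : vector<8xf32> to vector<4xf32>
/// is rewritten to vector.broadcast %f : f32 to vector<4xf32>.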
4678class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4679public:
4680 using Base::Base;
4681
4682 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4683 PatternRewriter &rewriter) const override {
4684
4685 Value splat = getScalarSplatSource(op.getSource());
4686 if (!splat)
4687 return failure();
4688 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4689 return success();
4690 }
4691};
4692
4693/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4694/// slice is contiguous, into extract and shape_cast.
4695///
4696/// Example:
4697/// Before:
4698/// %1 = vector.extract_strided_slice %arg0 {
4699/// offsets = [0, 0, 0, 0, 0],
4700/// sizes = [1, 1, 1, 1, 8],
4701/// strides = [1, 1, 1, 1, 1]
4702/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4703/// After:
4704/// %0 = vector.extract %arg0[0, 0, 0, 0]
4705/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4706/// %1 = vector.shape_cast %0
4707/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4708///
4709class ContiguousExtractStridedSliceToExtract final
4710 : public OpRewritePattern<ExtractStridedSliceOp> {
4711public:
4712 using Base::Base;
4713
4714 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4715 PatternRewriter &rewriter) const override {
4716 if (op.hasNonUnitStrides())
4717 return failure();
4718 Value source = op.getOperand();
4719 auto sourceType = cast<VectorType>(source.getType());
4720 if (sourceType.isScalable() || sourceType.getRank() == 0)
4721 return failure();
4722
4723 // Compute the number of offsets to pass to ExtractOp::build. That is the
4724 // difference between the source rank and the desired slice rank. We walk
4725 // the dimensions from innermost out, and stop when the next slice dimension
4726 // is not full-size.
4727 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4728 int numOffsets;
4729 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4730 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4731 break;
4732 }
4733
4734 // If the created extract op would have no offsets, then this whole
4735 // extract_strided_slice is the identity and should have been handled by
4736 // other canonicalizations.
4737 if (numOffsets == 0)
4738 return failure();
4739
4740 // If not even the inner-most dimension is full-size, this op can't be
4741 // rewritten as an ExtractOp.
4742 if (numOffsets == sourceType.getRank() &&
4743 static_cast<int>(sizes.size()) == sourceType.getRank())
4744 return failure();
4745
4746 // The outer dimensions must have unit size.
4747 for (int i = 0; i < numOffsets; ++i) {
4748 if (sizes[i] != 1)
4749 return failure();
4750 }
4751
4752     // Avoid generating slices that have leading unit dimensions. The shape_cast
4753     // op that we create below would otherwise be lowered through the slow
4754     // generic fallback pattern (ShapeCastOpRewritePattern).
4755 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4756 sizes[numOffsets] == 1) {
4757 ++numOffsets;
4758 }
4759
4760 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4761 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4762 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4763 extractOffsets);
4764 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4765 return success();
4766 }
4767};
4768
4769} // namespace
4770
4771void ExtractStridedSliceOp::getCanonicalizationPatterns(
4772 RewritePatternSet &results, MLIRContext *context) {
4773 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
4774 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4775 results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
4776 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4777 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4778 context);
4779}
4780
4781//===----------------------------------------------------------------------===//
4782// TransferReadOp
4783//===----------------------------------------------------------------------===//
4784
4785 /// 1. Builder that defaults padding to poison and sets an empty mask (variant with attrs).
4786void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4787 VectorType vectorType, Value source,
4788 ValueRange indices, std::optional<Value> padding,
4789 AffineMapAttr permutationMapAttr,
4790 /*optional*/ ArrayAttr inBoundsAttr) {
4791
4792 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4793 if (!padding)
4794 padding = ub::PoisonOp::create(builder, result.location, elemType);
4795 build(builder, result, vectorType, source, indices, permutationMapAttr,
4796 *padding, /*mask=*/Value(), inBoundsAttr);
4797}
4798
4799 /// 2. Builder that defaults padding to poison and sets an empty mask (variant without attrs).
4800void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4801 VectorType vectorType, Value source,
4802 ValueRange indices, std::optional<Value> padding,
4803 AffineMap permutationMap,
4804 std::optional<ArrayRef<bool>> inBounds) {
4805 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4806 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4807 ? builder.getBoolArrayAttr(inBounds.value())
4808 : builder.getBoolArrayAttr(
4809 SmallVector<bool>(vectorType.getRank(), false));
4810 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4811 if (!padding)
4812 padding = ub::PoisonOp::create(builder, result.location, elemType);
4813 build(builder, result, vectorType, source, indices, *padding,
4814 permutationMapAttr, inBoundsAttr);
4815}
4816
4817/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4818void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4819 VectorType vectorType, Value source,
4820 ValueRange indices, std::optional<Value> padding,
4821 std::optional<ArrayRef<bool>> inBounds) {
4822 AffineMap permutationMap = getTransferMinorIdentityMap(
4823 llvm::cast<ShapedType>(source.getType()), vectorType);
4824 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4825 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4826 ? builder.getBoolArrayAttr(inBounds.value())
4827 : builder.getBoolArrayAttr(
4828 SmallVector<bool>(vectorType.getRank(), false));
4829 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4830 if (!padding)
4831 padding = ub::PoisonOp::create(builder, result.location, elemType);
4832 build(builder, result, vectorType, source, indices, permutationMapAttr,
4833 *padding,
4834 /*mask=*/Value(), inBoundsAttr);
4835}
4836
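// Illustrative examples of maps accepted/rejected by verifyPermutationMap
// below (written for this comment):
//   (d0, d1, d2) -> (d2, d1)    accepted: each result is a distinct dim
//   (d0, d1, d2) -> (d1, 0)     accepted: the zero constant is allowed
//   (d0, d1, d2) -> (d0, d0)    rejected: d0 is used more than once
//   (d0, d1, d2) -> (d0 + d1)   rejected: not a bare dim or the zero constant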
4837template <typename EmitFun>
4838static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4839 EmitFun emitOpError) {
4840 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4841 for (auto expr : permutationMap.getResults()) {
4842 auto dim = dyn_cast<AffineDimExpr>(expr);
4843 auto zero = dyn_cast<AffineConstantExpr>(expr);
4844 if (zero) {
4845 if (zero.getValue() != 0) {
4846 return emitOpError(
4847 "requires a projected permutation_map (at most one dim or the zero "
4848 "constant can appear in each result)");
4849 }
4850 continue;
4851 }
4852 if (!dim) {
4853 return emitOpError("requires a projected permutation_map (at most one "
4854 "dim or the zero constant can appear in each result)");
4855 }
4856 if (seen[dim.getPosition()]) {
4857 return emitOpError(
4858 "requires a permutation_map that is a permutation (found one dim "
4859 "used more than once)");
4860 }
4861 seen[dim.getPosition()] = true;
4862 }
4863 return success();
4864}
4865
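// Illustrative sketch of the bitwidth rule checked below for sources with a
// vector element type (numbers made up for this comment): for a source such
// as memref<?xvector<4xf32>>, the minor 1-D source vector spans 4 * 32 = 128
// bits. A transfer vector whose minor dimension holds 8 x f32 (256 bits) is
// accepted because 256 % 128 == 0, while a minor dimension of 2 x f32
// (64 bits) is rejected.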
4866static LogicalResult
4867verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4868 VectorType vectorType, VectorType maskType,
4869 VectorType inferredMaskType, AffineMap permutationMap,
4870 ArrayAttr inBounds) {
4871 if (op->hasAttr("masked")) {
4872 return op->emitOpError("masked attribute has been removed. "
4873 "Use in_bounds instead.");
4874 }
4875
4876 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4877 return op->emitOpError(
4878 "requires source to be a memref or ranked tensor type");
4879
4880 auto elementType = shapedType.getElementType();
4881 DataLayout dataLayout = DataLayout::closest(op);
4882 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4883 // Memref or tensor has vector element type.
4884 unsigned sourceVecSize =
4885 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4886 vectorElementType.getShape().back();
4887 unsigned resultVecSize =
4888 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4889 vectorType.getShape().back();
4890 if (resultVecSize % sourceVecSize != 0)
4891 return op->emitOpError(
4892 "requires the bitwidth of the minor 1-D vector to be an integral "
4893 "multiple of the bitwidth of the minor 1-D vector of the source");
4894
4895 unsigned sourceVecEltRank = vectorElementType.getRank();
4896 unsigned resultVecRank = vectorType.getRank();
4897 if (sourceVecEltRank > resultVecRank)
4898 return op->emitOpError(
4899 "requires source vector element and vector result ranks to match.");
4900 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4901 // Check that permutation map results match 'rankOffset' of vector type.
4902 if (permutationMap.getNumResults() != rankOffset)
4903 return op->emitOpError("requires a permutation_map with result dims of "
4904 "the same rank as the vector type");
4905
4906 if (maskType)
4907 return op->emitOpError("does not support masks with vector element type");
4908 } else {
4909 // Memref or tensor has scalar element type.
4910 unsigned minorSize =
4911 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4912 unsigned resultVecSize =
4913 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4914 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4915 return op->emitOpError(
4916 "requires the bitwidth of the minor 1-D vector to be an integral "
4917 "multiple of the bitwidth of the source element type");
4918
4919 // Check that permutation map results match rank of vector type.
4920 if (permutationMap.getNumResults() != vectorType.getRank())
4921 return op->emitOpError("requires a permutation_map with result dims of "
4922 "the same rank as the vector type");
4923 }
4924
4925 if (permutationMap.getNumSymbols() != 0)
4926 return op->emitOpError("requires permutation_map without symbols");
4927
4928 if (permutationMap.getNumInputs() != shapedType.getRank())
4929 return op->emitOpError("requires a permutation_map with input dims of the "
4930 "same rank as the source type");
4931
4932 if (maskType && maskType != inferredMaskType)
4933 return op->emitOpError("inferred mask type (")
4934 << inferredMaskType << ") and mask operand type (" << maskType
4935 << ") don't match";
4936
4937 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4938 return op->emitOpError("expects the in_bounds attr of same rank "
4939 "as permutation_map results: ")
4940 << AffineMapAttr::get(permutationMap)
4941 << " vs inBounds of size: " << inBounds.size();
4942
4943 return success();
4944}
4945
4946static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4947 SmallVector<StringRef, 3> elidedAttrs;
4948 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4949 if (op.getPermutationMap().isMinorIdentity())
4950 elidedAttrs.push_back(op.getPermutationMapAttrName());
4951 // Elide in_bounds attribute if all dims are out-of-bounds.
4952 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4953 elidedAttrs.push_back(op.getInBoundsAttrName());
4954 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4955}
4956
4957void TransferReadOp::print(OpAsmPrinter &p) {
4958 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4959 if (getMask())
4960 p << ", " << getMask();
4961 printTransferAttrs(p, *this);
4962 p << " : " << getShapedType() << ", " << getVectorType();
4963}
4964
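// Illustrative example of mask type inference (written for this comment): for
// a transfer with permutation_map (d0, d1, d2) -> (d2, d1) and vector type
// vector<4x8xf32>, the unused dim d0 is compressed away, the remaining
// permutation is inverted, and the mask dims follow the order of the source
// dims, giving vector<8x4xi1>.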
4965VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4966 AffineMap permMap) {
4967 auto i1Type = IntegerType::get(permMap.getContext(), 1);
4968 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4969 assert(invPermMap && "Inversed permutation map couldn't be computed");
4970 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4971
4972 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4973 // 0-D mask into a single-element 1-D mask.
4974 if (maskShape.empty())
4975 maskShape.push_back(1);
4976
4977 SmallVector<bool> scalableDims =
4978 applyPermutationMap(invPermMap, vecType.getScalableDims());
4979
4980 return VectorType::get(maskShape, i1Type, scalableDims);
4981}
4982
4983ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4984 auto &builder = parser.getBuilder();
4985 SMLoc typesLoc;
4986   OpAsmParser::UnresolvedOperand sourceInfo;
4987   SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4988   OpAsmParser::UnresolvedOperand paddingInfo;
4989   SmallVector<Type, 2> types;
4990   OpAsmParser::UnresolvedOperand maskInfo;
4991   // Parsing with support for paddingValue.
4992   if (parser.parseOperand(sourceInfo) ||
4993       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4994       parser.parseComma() || parser.parseOperand(paddingInfo))
4995 return failure();
4996 ParseResult hasMask = parser.parseOptionalComma();
4997 if (hasMask.succeeded()) {
4998 if (parser.parseOperand(maskInfo))
4999 return failure();
5000 }
5001 if (parser.parseOptionalAttrDict(result.attributes) ||
5002 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5003 return failure();
5004 if (types.size() != 2)
5005 return parser.emitError(typesLoc, "requires two types");
5006 auto indexType = builder.getIndexType();
5007 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5008 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5009 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5010 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5011 if (!vectorType)
5012 return parser.emitError(typesLoc, "requires vector type");
5013 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
5014 Attribute permMapAttr = result.attributes.get(permMapAttrName);
5015 AffineMap permMap;
5016 if (!permMapAttr) {
5017 if (shapedType.getRank() <
5018 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5019 return parser.emitError(typesLoc,
5020 "expected a custom permutation_map when "
5021 "rank(source) != rank(destination)");
5022 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5023 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5024 } else {
5025 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5026 }
5027 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
5028 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5029 if (!inBoundsAttr) {
5030 result.addAttribute(inBoundsAttrName,
5031 builder.getBoolArrayAttr(
5032 SmallVector<bool>(permMap.getNumResults(), false)));
5033 }
5034 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5035 parser.resolveOperands(indexInfo, indexType, result.operands) ||
5036 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
5037 result.operands))
5038 return failure();
5039 if (hasMask.succeeded()) {
5040 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5041 return parser.emitError(
5042 maskInfo.location, "does not support masks with vector element type");
5043 if (vectorType.getRank() != permMap.getNumResults()) {
5044 return parser.emitError(typesLoc,
5045 "expected the same rank for the vector and the "
5046 "results of the permutation map");
5047 }
5048 // Instead of adding the mask type as an op type, compute it based on the
5049 // vector type and the permutation map (to keep the type signature small).
5050 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5051 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5052 return failure();
5053 }
5054 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5055 builder.getDenseI32ArrayAttr(
5056 {1, static_cast<int32_t>(indexInfo.size()), 1,
5057 static_cast<int32_t>(hasMask.succeeded())}));
5058 return parser.addTypeToList(vectorType, result.types);
5059}
5060
5061LogicalResult TransferReadOp::verify() {
5062 // Consistency of elemental types in source and vector.
5063 ShapedType shapedType = getShapedType();
5064 VectorType vectorType = getVectorType();
5065 VectorType maskType = getMaskType();
5066 auto paddingType = getPadding().getType();
5067 auto permutationMap = getPermutationMap();
5068 VectorType inferredMaskType =
5069 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5070 : VectorType();
5071 auto sourceElementType = shapedType.getElementType();
5072
5073 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
5074 return emitOpError("requires ") << shapedType.getRank() << " indices";
5075
5076 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5077 shapedType, vectorType, maskType,
5078 inferredMaskType, permutationMap, getInBounds())))
5079 return failure();
5080
5081 if (auto sourceVectorElementType =
5082 llvm::dyn_cast<VectorType>(sourceElementType)) {
5083 // Source has vector element type.
5084 // Check that 'sourceVectorElementType' and 'paddingType' types match.
5085 if (sourceVectorElementType != paddingType)
5086 return emitOpError(
5087 "requires source element type and padding type to match.");
5088
5089 } else {
5090 // Check that 'paddingType' is valid to store in a vector type.
5091 if (!VectorType::isValidElementType(paddingType))
5092 return emitOpError("requires valid padding vector elemental type");
5093
5094 // Check that padding type and vector element types match.
5095 if (paddingType != sourceElementType)
5096 return emitOpError(
5097 "requires formal padding and source of the same elemental type");
5098 }
5099
5100 return verifyPermutationMap(permutationMap,
5101 [&](Twine t) { return emitOpError(t); });
5102}
5103
5104// MaskableOpInterface methods.
5105
5106/// Returns the mask type expected by this operation. Mostly used for
5107 /// verification purposes. It requires the operation to be vectorized.
5108Type TransferReadOp::getExpectedMaskType() {
5109 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5110}
5111
5112//===----------------------------------------------------------------------===//
5113// TransferReadOp: VectorTransferOpInterface methods.
5114//===----------------------------------------------------------------------===//
5115VectorType TransferReadOp::getVectorType() {
5116 return cast<VectorType>(getVector().getType());
5117}
5118
5119template <typename TransferOp>
5120static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
5121 // TODO: support more aggressive createOrFold on:
5122 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
5123 if (op.getShapedType().isDynamicDim(indicesIdx))
5124 return false;
5125 Value index = op.getIndices()[indicesIdx];
5126 std::optional<int64_t> cstOp = getConstantIntValue(index);
5127 if (!cstOp.has_value())
5128 return false;
5129
5130 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5131 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5132
5133 return cstOp.value() + vectorSize <= sourceSize;
5134}
5135
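// Illustrative sketch of the in_bounds folding implemented below (example
// written for this comment): a transfer_read of vector<4xf32> from
// memref<8xf32> at the constant index 2 satisfies 2 + 4 <= 8, so a dim
// currently marked in_bounds = [false] can be promoted to [true].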
5136template <typename TransferOp>
5137static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
5138 // TODO: support 0-d corner case.
5139 // TODO: Be less conservative.
5140 if (op.getTransferRank() == 0)
5141 return failure();
5142 AffineMap permutationMap = op.getPermutationMap();
5143 bool changed = false;
5144 SmallVector<bool, 4> newInBounds;
5145 newInBounds.reserve(op.getTransferRank());
5146   // Indices of non-broadcast dims - used when analysing broadcast dims.
5147 SmallVector<unsigned> nonBcastDims;
5148
5149 // 1. Process non-broadcast dims
5150 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5151 // 1.1. Already marked as in-bounds, nothing to see here.
5152 if (op.isDimInBounds(i)) {
5153 newInBounds.push_back(true);
5154 continue;
5155 }
5156 // 1.2. Currently out-of-bounds, check whether we can statically determine
5157 // it is inBounds.
5158 bool inBounds = false;
5159 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5160 if (dimExpr) {
5161 inBounds = isInBounds(op, /*resultIdx=*/i,
5162 /*indicesIdx=*/dimExpr.getPosition());
5163 nonBcastDims.push_back(i);
5164 }
5165
5166 newInBounds.push_back(inBounds);
5167 // We commit the pattern if it is "more inbounds".
5168 changed |= inBounds;
5169 }
5170
5171 // 2. Handle broadcast dims
5172 // If all non-broadcast dims are "in bounds", then all bcast dims should be
5173 // "in bounds" as well.
5174 bool allNonBcastDimsInBounds = llvm::all_of(
5175 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5176 if (allNonBcastDimsInBounds) {
5177 for (size_t idx : permutationMap.getBroadcastDims()) {
5178 changed |= !newInBounds[idx];
5179 newInBounds[idx] = true;
5180 }
5181 }
5182
5183 if (!changed)
5184 return failure();
5185 // OpBuilder is only used as a helper to build an I64ArrayAttr.
5186 OpBuilder b(op.getContext());
5187 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5188 return success();
5189}
5190
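// Illustrative sketch (written for this comment): a transfer whose mask is
//   %m = vector.constant_mask [4] : vector<4xi1>
// carries an all-true mask, so the fold below simply drops the mask operand.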
5191template <typename TransferOp>
5192static LogicalResult foldTransferFullMask(TransferOp op) {
5193 auto mask = op.getMask();
5194 if (!mask)
5195 return failure();
5196
5197   if (getMaskFormat(mask) != MaskFormat::AllTrue)
5198     return failure();
5199
5200 op.getMaskMutable().clear();
5201 return success();
5202}
5203
5204/// ```
5205/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5206/// : vector<1x4xf32>, tensor<4x4xf32>
5207/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5208/// : tensor<4x4xf32>, vector<1x4xf32>
5209/// ```
5210/// -> Folds into
5211/// ```
5212/// %v0
5213/// ```
5214static Value foldRAW(TransferReadOp readOp) {
5215 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5216 return {};
5217 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5218 while (defWrite) {
5219 if (checkSameValueRAW(defWrite, readOp))
5220 return defWrite.getVector();
5221     if (!isDisjointTransferIndices(
5222             cast<VectorTransferOpInterface>(defWrite.getOperation()),
5223 cast<VectorTransferOpInterface>(readOp.getOperation())))
5224 break;
5225 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5226 }
5227 return {};
5228}
5229
5230OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5231 if (Value vec = foldRAW(*this))
5232 return vec;
5233 /// transfer_read(memrefcast) -> transfer_read
5234 if (succeeded(foldTransferInBoundsAttribute(*this)))
5235 return getResult();
5236 if (succeeded(foldTransferFullMask(*this)))
5237 return getResult();
5238 if (succeeded(memref::foldMemRefCast(*this)))
5239 return getResult();
5240 if (succeeded(tensor::foldTensorCast(*this)))
5241 return getResult();
5242 return OpFoldResult();
5243}
5244
5245std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5246 return llvm::to_vector<4>(getVectorType().getShape());
5247}
5248
5249void TransferReadOp::getEffects(
5250 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5251 &effects) {
5252 if (llvm::isa<MemRefType>(getShapedType()))
5253 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5254 SideEffects::DefaultResource::get());
5255}
5256
5257Speculation::Speculatability TransferReadOp::getSpeculatability() {
5258   if (hasPureTensorSemantics())
5259     return Speculation::Speculatable;
5260   return Speculation::NotSpeculatable;
5261 }
5262
5263 /// Given a projected permutation, invert the affine map, making the unused
5264 /// dims 0 in the result.
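/// Illustrative example (written for this comment): the projected permutation
///   (d0, d1, d2) -> (d2, d0)
/// is inverted to
///   (d0, d1) -> (d1, 0, d0)
/// where the slot of the unused source dim d1 becomes the constant 0.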
5265static AffineMap inverseWithUnusedDims(AffineMap map) {
5266 assert(map.isProjectedPermutation() &&
5267 "expected a projected permutation map");
5268   SmallVector<AffineExpr> results(map.getNumInputs(),
5269                                   getAffineConstantExpr(0, map.getContext()));
5270 for (auto [idx, result] : llvm::enumerate(map.getResults())) {
5271 // We should only have dim exprs because this is a projected permutation.
5272 int64_t pos = cast<AffineDimExpr>(result).getPosition();
5273 results[pos] = getAffineDimExpr(idx, map.getContext());
5274 }
5275 return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
5276 results, map.getContext());
5277}
5278
5279namespace {
5280 /// Store to load forwarding for transfer operations with permutation maps.
5281 /// Even if the permutation maps are different, we can still propagate the store
5282/// into the load if the size of the dimensions read and written match. Then we
5283/// can replace the transfer_read + transfer_write by vector.broadcast and
5284/// vector.transpose.
5285/// Example:
5286/// ```
5287/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5288/// {in_bounds = [true, true],
5289/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5290/// vector<4x1xf32>, tensor<4x4x4xf32>
5291/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5292/// {in_bounds = [true, true, true, true],
5293/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5294/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5295/// ```
5296/// To:
5297/// ```
5298/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
5299/// %r = vector.transpose %0, [3, 0, 2, 1] :
5300/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5301/// ```
5302struct TransferReadAfterWriteToBroadcast
5303 : public OpRewritePattern<TransferReadOp> {
5304 using Base::Base;
5305
5306 LogicalResult matchAndRewrite(TransferReadOp readOp,
5307 PatternRewriter &rewriter) const override {
5308 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5309 if (!defWrite)
5310 return failure();
5311 // Bail if we need an alias analysis.
5312 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5313 return failure();
5314     // Bail in the masked case (too complex for now; the mask would need to be
5315     // accounted for together with the padding).
5316 if (readOp.getMask() || defWrite.getMask())
5317 return failure();
5318 // If indices are not the same a shift may be required, bail.
5319 if (readOp.getIndices() != defWrite.getIndices())
5320 return failure();
5321 // Bail if we need a bounds analysis.
5322 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5323 return failure();
5324 // TODO: If the written transfer chunk is a superset of the read transfer
5325 // chunk we could do an extract_strided_slice.
5326 if (readOp.getTransferChunkAccessed() !=
5327 defWrite.getTransferChunkAccessed())
5328 return failure();
5329 // WriteMap: tensor -> w_vec
5330 // ReadMap: tensor -> r_vec
5331 //
5332 // inv(WriteMap): w_vec -> tensor
5333 // inv(WriteMap) o ReadMap: w_vec -> r_vec
5334 AffineMap readMap = readOp.getPermutationMap();
5335 AffineMap writeMap = defWrite.getPermutationMap();
5336 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5337 AffineMap composedMap = readMap.compose(invWriteMap);
5338 // If there are any unused dims in the composedMap, we have to drop some
5339 // unit dims from the written vector before we can do transpose(broadcast).
5340 // TODO: Support this case.
5341 if (getUnusedDimsBitVector(composedMap).any())
5342 return failure();
5343 // readVec = transpose(broadcast(writeVec))
5344 //
5345 // Build a transpose permutation for the above transpose operation.
5346 //
5347 // Treat the composed map as having extra leading dimensions which are
5348 // the broadcasted dimensions, and treat the zeros as these new broadcasted
5349 // dimensions.
5350 SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
5351 int64_t numBroadcastedDims = broadcastedDims.size();
5352 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5353 invPerm.resize(composedMap.getNumResults());
5354 for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
5355 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
5356 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5357 invPerm[effectiveDim] = idx;
5358 }
5359 }
5360 // Applying the inverse permutation on the readVecTy will give us the
5361 // broadcast result type.
5362 VectorType readVecTy = readOp.getVectorType();
5363 SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
5364 auto broadcastedVecTy =
5365 VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
5366 readVecTy.getElementType(),
5367 applyPermutation(readVecTy.getScalableDims(), invPerm));
5368 // Build the transpose(broadcast) transformation.
5369 Value vec = defWrite.getVector();
5370 Location loc = readOp.getLoc();
5371 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5372 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
5373 return success();
5374 }
5375};
5376} // namespace
5377
5378void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5379 MLIRContext *context) {
5380 results.add<TransferReadAfterWriteToBroadcast>(context);
5381}
5382
5383FailureOr<std::optional<SmallVector<Value>>>
5384TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5385 if (!hasPureBufferSemantics())
5386 return failure();
5388 getResult());
5389}
5390
5391//===----------------------------------------------------------------------===//
5392// TransferWriteOp
5393//===----------------------------------------------------------------------===//
5394
5395/// 1. Builder with type inference.
5396void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5397 Value vector, Value dest, ValueRange indices,
5398 AffineMapAttr permutationMapAttr,
5399 /*optional*/ Value mask,
5400 /*optional*/ ArrayAttr inBoundsAttr) {
5401 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5402 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5403 mask, inBoundsAttr);
5404}
5405
5406/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5407void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5408 Value vector, Value dest, ValueRange indices,
5409 AffineMapAttr permutationMapAttr,
5410 /*optional*/ ArrayAttr inBoundsAttr) {
5411 build(builder, result, vector, dest, indices, permutationMapAttr,
5412 /*mask=*/Value(), inBoundsAttr);
5413}
5414
5415/// 3. Builder with type inference that sets an empty mask (variant without
5416/// attrs)
5417void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5418 Value vector, Value dest, ValueRange indices,
5419 AffineMap permutationMap,
5420 std::optional<ArrayRef<bool>> inBounds) {
5421 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5422 auto inBoundsAttr =
5423 (inBounds && !inBounds.value().empty())
5424 ? builder.getBoolArrayAttr(inBounds.value())
5425 : builder.getBoolArrayAttr(SmallVector<bool>(
5426 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5427 build(builder, result, vector, dest, indices, permutationMapAttr,
5428 /*mask=*/Value(), inBoundsAttr);
5429}
5430
5431/// 4. Builder with type inference that sets an empty mask and sets permutation
5432/// map to 'getMinorIdentityMap'.
5433void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5434 Value vector, Value dest, ValueRange indices,
5435 std::optional<ArrayRef<bool>> inBounds) {
5436 auto vectorType = llvm::cast<VectorType>(vector.getType());
5437 AffineMap permutationMap = getTransferMinorIdentityMap(
5438 llvm::cast<ShapedType>(dest.getType()), vectorType);
5439 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5440}
5441
5442ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5443 OperationState &result) {
5444 auto &builder = parser.getBuilder();
5445 SMLoc typesLoc;
5446 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5447 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5448 SmallVector<Type, 2> types;
5449 OpAsmParser::UnresolvedOperand maskInfo;
5450 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5451 parser.parseOperand(sourceInfo) ||
5452 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5453 return failure();
5454 ParseResult hasMask = parser.parseOptionalComma();
5455 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5456 return failure();
5457 if (parser.parseOptionalAttrDict(result.attributes) ||
5458 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5459 return failure();
5460 if (types.size() != 2)
5461 return parser.emitError(typesLoc, "requires two types");
5462 auto indexType = builder.getIndexType();
5463 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5464 if (!vectorType)
5465 return parser.emitError(typesLoc, "requires vector type");
5466 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5467 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5468 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5469 auto permMapAttrName =
5470 TransferWriteOp::getPermutationMapAttrName(result.name);
5471 auto permMapAttr = result.attributes.get(permMapAttrName);
5472 AffineMap permMap;
5473 if (!permMapAttr) {
5474 if (shapedType.getRank() <
5475 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5476 return parser.emitError(typesLoc,
5477 "expected a custom permutation_map when "
5478 "rank(source) != rank(destination)");
5479 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5480 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5481 } else {
5482 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5483 }
5484 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5485 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5486 if (!inBoundsAttr) {
5487 result.addAttribute(inBoundsAttrName,
5488 builder.getBoolArrayAttr(
5489 SmallVector<bool>(permMap.getNumResults(), false)));
5490 }
5491 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5492 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5493 parser.resolveOperands(indexInfo, indexType, result.operands))
5494 return failure();
5495 if (hasMask.succeeded()) {
5496 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5497 return parser.emitError(
5498 maskInfo.location, "does not support masks with vector element type");
5499 if (vectorType.getRank() != permMap.getNumResults()) {
5500 return parser.emitError(typesLoc,
5501 "expected the same rank for the vector and the "
5502 "results of the permutation map");
5503 }
5504 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5505 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5506 return failure();
5507 }
5508 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5509 builder.getDenseI32ArrayAttr(
5510 {1, 1, static_cast<int32_t>(indexInfo.size()),
5511 static_cast<int32_t>(hasMask.succeeded())}));
5512 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5513 parser.addTypeToList(shapedType, result.types));
5514}
5515
5516void TransferWriteOp::print(OpAsmPrinter &p) {
5517 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5518 if (getMask())
5519 p << ", " << getMask();
5520 printTransferAttrs(p, *this);
5521 p << " : " << getVectorType() << ", " << getShapedType();
5522}
5523
5524LogicalResult TransferWriteOp::verify() {
5525 // Consistency of elemental types in shape and vector.
5526 ShapedType shapedType = getShapedType();
5527 VectorType vectorType = getVectorType();
5528 VectorType maskType = getMaskType();
5529 auto permutationMap = getPermutationMap();
5530 VectorType inferredMaskType =
5531 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5532 : VectorType();
5533
5534 if (llvm::size(getIndices()) != shapedType.getRank())
5535 return emitOpError("requires ") << shapedType.getRank() << " indices";
5536
5537 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5538 // as the semantics is unclear. This can be revisited later if necessary.
5539 if (hasBroadcastDim())
5540 return emitOpError("should not have broadcast dimensions");
5541
5542 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5543 shapedType, vectorType, maskType,
5544 inferredMaskType, permutationMap, getInBounds())))
5545 return failure();
5546
5547 return verifyPermutationMap(permutationMap,
5548 [&](Twine t) { return emitOpError(t); });
5549}
5550
5551//===----------------------------------------------------------------------===//
5552// TransferWriteOp: MaskableOpInterface methods.
5553//===----------------------------------------------------------------------===//
5554
5555/// Returns the mask type expected by this operation. Mostly used for
5556/// verification purposes.
5557Type TransferWriteOp::getExpectedMaskType() {
5558 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5559}
5560
5561//===----------------------------------------------------------------------===//
5562// TransferWriteOp: VectorTransferOpInterface methods.
5563//===----------------------------------------------------------------------===//
5564Value TransferWriteOp::getVector() { return getOperand(0); }
5565VectorType TransferWriteOp::getVectorType() {
5566 return cast<VectorType>(getValueToStore().getType());
5567}
5568
5569//===----------------------------------------------------------------------===//
5570// TransferWriteOp: fold methods.
5571//===----------------------------------------------------------------------===//
5572/// Fold:
5573/// ```
5574/// %t1 = ...
5575/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5576/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5577/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5578/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5579/// ```
5580///
5581/// into:
5582///
5583/// ```
5584/// %t0
5585/// ```
5586///
5587/// The producer of t1 may or may not be DCE'd depending on whether it is a
5588/// block argument or has side effects.
5589static LogicalResult foldReadInitWrite(TransferWriteOp write,
5590 ArrayRef<Attribute>,
5591 SmallVectorImpl<OpFoldResult> &results) {
5592 // TODO: support 0-d corner case.
5593 if (write.getTransferRank() == 0)
5594 return failure();
5595 auto rankedTensorType =
5596 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5597 // If not operating on tensors, bail.
5598 if (!rankedTensorType)
5599 return failure();
5600 // If no read, bail.
5601 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5602 if (!read)
5603 return failure();
5604 // TODO: support 0-d corner case.
5605 if (read.getTransferRank() == 0)
5606 return failure();
5607   // For now, only accept minor identity maps. Future: also accept maps whose composition is a minor identity.
5608 if (!read.getPermutationMap().isMinorIdentity() ||
5609 !write.getPermutationMap().isMinorIdentity())
5610 return failure();
5611 // Bail on mismatching ranks.
5612 if (read.getTransferRank() != write.getTransferRank())
5613 return failure();
5614 // Bail on potential out-of-bounds accesses.
5615 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5616 return failure();
5617 // Tensor types must be the same.
5618 if (read.getBase().getType() != rankedTensorType)
5619 return failure();
5620 // Vector types must be the same.
5621 if (read.getVectorType() != write.getVectorType())
5622 return failure();
5623 // Vector and Tensor shapes must match.
5624 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5625 return failure();
5626 // If any index is nonzero.
5627 auto isNotConstantZero = [](Value v) {
5628 auto cstOp = getConstantIntValue(v);
5629 return !cstOp.has_value() || cstOp.value() != 0;
5630 };
5631 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5632 llvm::any_of(write.getIndices(), isNotConstantZero))
5633 return failure();
5634 // Success.
5635 results.push_back(read.getBase());
5636 return success();
5637}
5638
5639static bool checkSameValueWAR(vector::TransferReadOp read,
5640 vector::TransferWriteOp write) {
5641 return read.getBase() == write.getBase() &&
5642 read.getIndices() == write.getIndices() &&
5643 read.getPermutationMap() == write.getPermutationMap() &&
5644 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5645 !write.getMask();
5646}
5647/// Fold transfer_write write after read:
5648/// ```
5649/// %t0 = ...
5650/// %v = vector.transfer_read %t0[%c0...] :
5651/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5652/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5653/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5654/// ```
5655///
5656/// into:
5657///
5658/// ```
5659/// %t0
5660/// ```
5661static LogicalResult foldWAR(TransferWriteOp write,
5662 SmallVectorImpl<OpFoldResult> &results) {
5663 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5664 return failure();
5665 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5666 if (!read)
5667 return failure();
5668
5669 if (!checkSameValueWAR(read, write))
5670 return failure();
5671 results.push_back(read.getBase());
5672 return success();
5673}
5674
5675LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5676 SmallVectorImpl<OpFoldResult> &results) {
5677 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5678 return success();
5679 if (succeeded(foldWAR(*this, results)))
5680 return success();
5681 if (succeeded(foldTransferInBoundsAttribute(*this)))
5682 return success();
5683 if (succeeded(foldTransferFullMask(*this)))
5684 return success();
5685 return memref::foldMemRefCast(*this);
5686}
5687
5688//===----------------------------------------------------------------------===//
5689// TransferWriteOp: other methods.
5690//===----------------------------------------------------------------------===//
5691std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5692 return llvm::to_vector<4>(getVectorType().getShape());
5693}
5694
5695void TransferWriteOp::getEffects(
5696 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5697 &effects) {
5698 if (llvm::isa<MemRefType>(getShapedType()))
5699 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5700 SideEffects::DefaultResource::get());
5701}
5702
5703Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5704   if (hasPureTensorSemantics())
5705     return Speculation::Speculatable;
5706   return Speculation::NotSpeculatable;
5707 }
5708
5709namespace {
5710 /// Remove dead transfer write from the SSA chain so that it can be eliminated
5711 /// by DCE.
5712/// ```
5713/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5714/// : vector<1x4xf32>, tensor<4x4xf32>
5715/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5716/// : vector<1x4xf32>, tensor<4x4xf32>
5717/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5718/// : vector<1x4xf32>, tensor<4x4xf32>
5719/// ```
5720///
5721/// into:
5722///
5723/// ```
5724/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5725/// : vector<1x4xf32>, tensor<4x4xf32>
5726/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5727/// : vector<1x4xf32>, tensor<4x4xf32>
5728/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5729/// : vector<1x4xf32>, tensor<4x4xf32>
5730/// ```
5731///
5732/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5733/// any other uses.
5734class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5735public:
5736 using Base::Base;
5737 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5738 PatternRewriter &rewriter) const override {
5739 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5740 return failure();
5741 vector::TransferWriteOp writeToModify = writeOp;
5742
5743 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5744 while (defWrite) {
5745 if (checkSameValueWAW(writeOp, defWrite)) {
5746 rewriter.modifyOpInPlace(writeToModify, [&]() {
5747 writeToModify.getBaseMutable().assign(defWrite.getBase());
5748 });
5749 return success();
5750 }
5751       if (!isDisjointTransferIndices(
5752               cast<VectorTransferOpInterface>(defWrite.getOperation()),
5753 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5754 break;
5755       // If the previous write op doesn't have any other use we can safely look
5756 // at the previous store to see if it can be removed.
5757 if (!defWrite->hasOneUse())
5758 break;
5759 writeToModify = defWrite;
5760 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5761 }
5762 return failure();
5763 }
5764};
5765
5766/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5767/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5768/// overwritten and inserted into another tensor. After this rewrite, the
5769/// operations bufferize in-place since all of them work on the same slice.
5770///
5771/// For example:
5772/// ```mlir
5773/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5774/// : vector<8x16xf32>, tensor<8x16xf32>
5775/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5776/// : tensor<8x16xf32> to tensor<?x?xf32>
5777/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5778/// : tensor<?x?xf32> into tensor<27x37xf32>
5779/// ```
5780/// folds to
5781/// ```mlir
5782/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5783/// : tensor<27x37xf32> to tensor<?x?xf32>
5784/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5785/// : vector<8x16xf32>, tensor<?x?xf32>
5786/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5787/// : tensor<?x?xf32> into tensor<27x37xf32>
5788/// ```
5789struct SwapExtractSliceOfTransferWrite
5790 : public OpRewritePattern<tensor::InsertSliceOp> {
5791public:
5792 using Base::Base;
5793
5794 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5795 PatternRewriter &rewriter) const override {
5796 if (!insertOp.hasUnitStride())
5797 return failure();
5798 auto extractOp =
5799 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5800 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5801 return failure();
5802 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5803 if (!transferOp || !transferOp->hasOneUse())
5804 return failure();
5805
5806 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5807 // rank-reducing.
5808 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5809 return rewriter.notifyMatchFailure(insertOp,
5810 "use-def chain is rank-reducing");
5811 }
5812
5813 // Fail if tensor::ExtractSliceOp has non-zero offset.
5814 if (!extractOp.hasZeroOffset()) {
5815 return rewriter.notifyMatchFailure(insertOp,
5816 "ExtractSliceOp has non-zero offset");
5817 }
5818
5819     // Fail if vector::TransferWriteOp has non-zero offset.
5820 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5821 return getConstantIntValue(value) == static_cast<int64_t>(0);
5822 })) {
5823 return rewriter.notifyMatchFailure(insertOp,
5824 "TranferWriteOp has non-zero offset");
5825 }
5826
5827 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5828 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5829 return rewriter.notifyMatchFailure(
5830 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5831 }
5832
5833 for (auto [insertSize, extractSize] :
5834 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5835 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5836 return rewriter.notifyMatchFailure(
5837 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5838 }
5839 }
5840
5841 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5842 assert(transferOp.getVectorType().hasStaticShape() &&
5843 "expected vector to have a static shape");
5844 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5845 SmallVector<int64_t> resultShape = applyPermutationMap(
5846 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5847 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5848 return rewriter.notifyMatchFailure(
5849 insertOp, "TransferWriteOp may not write the full tensor.");
5850 }
5851
5852 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5853 // Set all in_bounds to false and let the folder infer them.
5854 SmallVector<bool> newInBounds(vectorShape.size(), false);
5855 auto newExtractOp = tensor::ExtractSliceOp::create(
5856 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5857 insertOp.getDest(), insertOp.getMixedOffsets(),
5858 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5859 auto newTransferWriteOp = TransferWriteOp::create(
5860 rewriter, transferOp.getLoc(), transferOp.getVector(),
5861 newExtractOp.getResult(), transferOp.getIndices(),
5862 transferOp.getPermutationMapAttr(),
5863 rewriter.getBoolArrayAttr(newInBounds));
5864 rewriter.modifyOpInPlace(insertOp, [&]() {
5865 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5866 });
5867 return success();
5868 }
5869};
5870
5871} // namespace
5872
5873void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5874 MLIRContext *context) {
5875 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5876}
5877
5878FailureOr<std::optional<SmallVector<Value>>>
5879TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5880 if (!hasPureBufferSemantics())
5881 return failure();
5883 ValueRange());
5884}
5885
5886//===----------------------------------------------------------------------===//
5887// LoadOp
5888//===----------------------------------------------------------------------===//
5889
5890static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5891 VectorType vecTy,
5892 MemRefType memRefTy) {
5893  // If rank==0 or size==1 it's equivalent to a scalar load/store, so we don't
5894  // need any stride restrictions.
5895 if (!vecTy.isScalable() &&
5896 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5897 return success();
5898
5899 if (!memRefTy.isLastDimUnitStride())
5900 return op->emitOpError("most minor memref dim must have unit stride");
5901 return success();
5902}
5903
5904LogicalResult vector::LoadOp::verify() {
5905 VectorType resVecTy = getVectorType();
5906 MemRefType memRefTy = getMemRefType();
5907
5908 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5909 return failure();
5910
5911 if (memRefTy.getRank() < resVecTy.getRank())
5912 return emitOpError(
5913 "destination memref has lower rank than the result vector");
5914
5915 // Checks for vector memrefs.
5916 Type memElemTy = memRefTy.getElementType();
5917 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5918 if (memVecTy != resVecTy)
5919 return emitOpError("base memref and result vector types should match");
5920 memElemTy = memVecTy.getElementType();
5921 }
5922
5923 if (resVecTy.getElementType() != memElemTy)
5924 return emitOpError("base and result element types should match");
5925 if (llvm::size(getIndices()) != memRefTy.getRank())
5926 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5927 return success();
5928}
5929
5930OpFoldResult LoadOp::fold(FoldAdaptor) {
5931 if (succeeded(memref::foldMemRefCast(*this)))
5932 return getResult();
5933 return OpFoldResult();
5934}
5935
5936std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5937 return llvm::to_vector<4>(getVectorType().getShape());
5938}
5939
5940FailureOr<std::optional<SmallVector<Value>>>
5941LoadOp::bubbleDownCasts(OpBuilder &builder) {
5943 getResult());
5944}
5945
5946//===----------------------------------------------------------------------===//
5947// StoreOp
5948//===----------------------------------------------------------------------===//
5949
5950LogicalResult vector::StoreOp::verify() {
5951 VectorType valueVecTy = getVectorType();
5952 MemRefType memRefTy = getMemRefType();
5953
5954 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5955 return failure();
5956
5957 if (memRefTy.getRank() < valueVecTy.getRank())
5958 return emitOpError("source memref has lower rank than the vector to store");
5959
5960 // Checks for vector memrefs.
5961 Type memElemTy = memRefTy.getElementType();
5962 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5963 if (memVecTy != valueVecTy)
5964 return emitOpError(
5965 "base memref and valueToStore vector types should match");
5966 memElemTy = memVecTy.getElementType();
5967 }
5968
5969 if (valueVecTy.getElementType() != memElemTy)
5970 return emitOpError("base and valueToStore element type should match");
5971 if (llvm::size(getIndices()) != memRefTy.getRank())
5972 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5973 return success();
5974}
5975
5976LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5977 SmallVectorImpl<OpFoldResult> &results) {
5978 return memref::foldMemRefCast(*this);
5979}
5980
5981std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5982 return llvm::to_vector<4>(getVectorType().getShape());
5983}
5984
5985FailureOr<std::optional<SmallVector<Value>>>
5986StoreOp::bubbleDownCasts(OpBuilder &builder) {
5988 ValueRange());
5989}
5990
5991//===----------------------------------------------------------------------===//
5992// MaskedLoadOp
5993//===----------------------------------------------------------------------===//
5994
5995LogicalResult MaskedLoadOp::verify() {
5996 VectorType maskVType = getMaskVectorType();
5997 VectorType passVType = getPassThruVectorType();
5998 VectorType resVType = getVectorType();
5999 MemRefType memType = getMemRefType();
6000
6001 if (failed(
6002 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6003 return failure();
6004 if (llvm::size(getIndices()) != memType.getRank())
6005 return emitOpError("requires ") << memType.getRank() << " indices";
6006 if (resVType.getShape() != maskVType.getShape())
6007 return emitOpError("expected result shape to match mask shape");
6008 if (resVType != passVType)
6009 return emitOpError("expected pass_thru of same type as result type");
6010 return success();
6011}
6012
6013namespace {
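/// Canonicalizes a vector.maskedload whose mask is statically all-true into a
/// plain vector.load, and one whose mask is all-false into its pass_thru
/// value. Illustrative sketch (operand names and shapes are placeholders):
///   %0 = vector.maskedload %base[%i], %all_true, %pass
///       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// becomes
///   %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>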
6014class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
6015public:
6016 using Base::Base;
6017 LogicalResult matchAndRewrite(MaskedLoadOp load,
6018 PatternRewriter &rewriter) const override {
6019    switch (getMaskFormat(load.getMask())) {
6020    case MaskFormat::AllTrue:
6021      rewriter.replaceOpWithNewOp<vector::LoadOp>(
6022          load, load.getType(), load.getBase(), load.getIndices());
6023      return success();
6024    case MaskFormat::AllFalse:
6025      rewriter.replaceOp(load, load.getPassThru());
6026      return success();
6027    case MaskFormat::Unknown:
6028      return failure();
6029 }
6030 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
6031 }
6032};
6033} // namespace
6034
6035void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6036 MLIRContext *context) {
6037 results.add<MaskedLoadFolder>(context);
6038}
6039
6040OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6041 if (succeeded(memref::foldMemRefCast(*this)))
6042 return getResult();
6043 return OpFoldResult();
6044}
6045
6046FailureOr<std::optional<SmallVector<Value>>>
6047MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6049 getResult());
6050}
6051
6052//===----------------------------------------------------------------------===//
6053// MaskedStoreOp
6054//===----------------------------------------------------------------------===//
6055
6056LogicalResult MaskedStoreOp::verify() {
6057 VectorType maskVType = getMaskVectorType();
6058 VectorType valueVType = getVectorType();
6059 MemRefType memType = getMemRefType();
6060
6061 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6062 "valueToStore")))
6063 return failure();
6064 if (llvm::size(getIndices()) != memType.getRank())
6065 return emitOpError("requires ") << memType.getRank() << " indices";
6066 if (valueVType.getShape() != maskVType.getShape())
6067 return emitOpError("expected valueToStore shape to match mask shape");
6068 return success();
6069}
6070
6071namespace {
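/// Canonicalizes a vector.maskedstore with a statically all-true mask into a
/// plain vector.store, and erases it entirely when the mask is all-false.
/// Illustrative sketch (operand names and shapes are placeholders):
///   vector.maskedstore %base[%i], %all_true, %val
///       : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// becomes
///   vector.store %val, %base[%i] : memref<?xf32>, vector<8xf32>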
6072class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
6073public:
6074 using Base::Base;
6075 LogicalResult matchAndRewrite(MaskedStoreOp store,
6076 PatternRewriter &rewriter) const override {
6077    switch (getMaskFormat(store.getMask())) {
6078    case MaskFormat::AllTrue:
6079      rewriter.replaceOpWithNewOp<vector::StoreOp>(
6080          store, store.getValueToStore(), store.getBase(), store.getIndices());
6081      return success();
6082    case MaskFormat::AllFalse:
6083      rewriter.eraseOp(store);
6084      return success();
6085    case MaskFormat::Unknown:
6086      return failure();
6087 }
6088 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
6089 }
6090};
6091} // namespace
6092
6093void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6094 MLIRContext *context) {
6095 results.add<MaskedStoreFolder>(context);
6096}
6097
6098LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6099 SmallVectorImpl<OpFoldResult> &results) {
6100 return memref::foldMemRefCast(*this);
6101}
6102
6103FailureOr<std::optional<SmallVector<Value>>>
6104MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6106 ValueRange());
6107}
6108
6109//===----------------------------------------------------------------------===//
6110// GatherOp
6111//===----------------------------------------------------------------------===//
6112
6113LogicalResult GatherOp::verify() {
6114 VectorType indVType = getIndexVectorType();
6115 VectorType maskVType = getMaskVectorType();
6116 VectorType resVType = getVectorType();
6117 ShapedType baseType = getBaseType();
6118
6119 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6120 return emitOpError("requires base to be a memref or ranked tensor type");
6121
6122 if (failed(
6123 verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
6124 return failure();
6125 if (llvm::size(getOffsets()) != baseType.getRank())
6126 return emitOpError("requires ") << baseType.getRank() << " indices";
6127 if (resVType.getShape() != indVType.getShape())
6128 return emitOpError("expected result dim to match indices dim");
6129 if (resVType.getShape() != maskVType.getShape())
6130 return emitOpError("expected result dim to match mask dim");
6131 if (resVType != getPassThruVectorType())
6132 return emitOpError("expected pass_thru of same type as result type");
6133 return success();
6134}
6135
6136// MaskableOpInterface methods.
6137
6138/// Returns the mask type expected by this operation. Mostly used for
6139 /// verification purposes. It requires the operation to be vectorized.
6140Type GatherOp::getExpectedMaskType() {
6141 auto vecType = this->getIndexVectorType();
6142 return VectorType::get(vecType.getShape(),
6143 IntegerType::get(vecType.getContext(), /*width=*/1),
6144 vecType.getScalableDims());
6145}
6146
6147std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6148 return llvm::to_vector<4>(getVectorType().getShape());
6149}
6150
6151 /// Check whether `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
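/// Illustrative matches (shapes are placeholders): the result of
///   %s = vector.step : vector<4xindex>
/// or the constant
///   %c = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>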
6152static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6153 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6154 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6155 return failure();
6156
6157 if (indexVec.getDefiningOp<StepOp>())
6158 return success();
6159
6160 DenseIntElementsAttr elements;
6161 if (!matchPattern(indexVec, m_Constant(&elements)))
6162 return failure();
6163
6164 return success(
6165 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6166}
6167
6168namespace {
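/// Canonicalizes a vector.gather with a statically all-false mask into its
/// pass_thru value. An all-true mask has no unmasked gather equivalent, so
/// that case is left to other patterns such as FoldContiguousGather below.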
6169class GatherFolder final : public OpRewritePattern<GatherOp> {
6170public:
6171 using Base::Base;
6172 LogicalResult matchAndRewrite(GatherOp gather,
6173 PatternRewriter &rewriter) const override {
6174    switch (getMaskFormat(gather.getMask())) {
6175    case MaskFormat::AllTrue:
6176      return failure(); // no unmasked equivalent
6177    case MaskFormat::AllFalse:
6178      rewriter.replaceOp(gather, gather.getPassThru());
6179      return success();
6180    case MaskFormat::Unknown:
6181      return failure();
6182 }
6183 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6184 }
6185};
6186
6187/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6188/// maskedload. Only 1D fixed vectors are supported for now.
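/// Illustrative sketch (names and shapes are placeholders):
///   %0 = vector.gather %base[%c0][%seq], %mask, %pass
///       : memref<16xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
///         into vector<16xf32>
/// with %seq == [0, 1, ..., 15] becomes
///   %0 = vector.maskedload %base[%c0], %mask, %pass
///       : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>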
6189class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6190public:
6191 using Base::Base;
6192 LogicalResult matchAndRewrite(GatherOp op,
6193 PatternRewriter &rewriter) const override {
6194 if (!isa<MemRefType>(op.getBase().getType()))
6195 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6196
6197 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6198 return failure();
6199
6200 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6201 op.getOffsets(), op.getMask(),
6202 op.getPassThru());
6203 return success();
6204 }
6205};
6206} // namespace
6207
6208void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6209 MLIRContext *context) {
6210 results.add<GatherFolder, FoldContiguousGather>(context);
6211}
6212
6213FailureOr<std::optional<SmallVector<Value>>>
6214GatherOp::bubbleDownCasts(OpBuilder &builder) {
6216 getResult());
6217}
6218
6219//===----------------------------------------------------------------------===//
6220// ScatterOp
6221//===----------------------------------------------------------------------===//
6222
6223LogicalResult ScatterOp::verify() {
6224 VectorType indVType = getIndexVectorType();
6225 VectorType maskVType = getMaskVectorType();
6226 VectorType valueVType = getVectorType();
6227 ShapedType baseType = getBaseType();
6228
6229 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6230 return emitOpError("requires base to be a memref or ranked tensor type");
6231
6232 if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
6233 "valueToStore")))
6234 return failure();
6235 if (llvm::size(getOffsets()) != baseType.getRank())
6236 return emitOpError("requires ") << baseType.getRank() << " indices";
6237 if (valueVType.getShape() != indVType.getShape())
6238 return emitOpError("expected valueToStore dim to match indices dim");
6239 if (valueVType.getShape() != maskVType.getShape())
6240 return emitOpError("expected valueToStore dim to match mask dim");
6241 return success();
6242}
6243namespace {
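/// Canonicalizes a vector.scatter with a statically all-false mask: a scatter
/// into a memref is simply erased, while a scatter into a tensor is replaced
/// by its original base tensor. An all-true mask has no unmasked equivalent.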
6244class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6245public:
6246 using Base::Base;
6247 LogicalResult matchAndRewrite(ScatterOp scatter,
6248 PatternRewriter &rewriter) const override {
6249 ShapedType baseType = scatter.getBaseType();
6250 bool isMemRef = isa<MemRefType>(baseType);
6251 if (!isMemRef && !isa<RankedTensorType>(baseType))
6252 return failure();
6253
6254 // Memrefs have no result, so an all-false mask can simply erase the op.
6255 // Tensors carry the updated value, so we must replace uses with the
6256 // original base tensor instead of erasing.
6257    switch (getMaskFormat(scatter.getMask())) {
6258    case MaskFormat::AllTrue:
6259      return failure(); // no unmasked equivalent
6260    case MaskFormat::AllFalse:
6261      if (isMemRef)
6262        rewriter.eraseOp(scatter);
6263      else
6264        rewriter.replaceOp(scatter, scatter.getBase());
6265      return success();
6266    case MaskFormat::Unknown:
6267      return failure();
6268 }
6269 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6270 }
6271};
6272
6273/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6274/// maskedstore. Only 1D fixed vectors are supported for now.
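/// Illustrative sketch (names and shapes are placeholders):
///   vector.scatter %base[%c0][%seq], %mask, %val
///       : memref<16xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
/// with %seq == [0, 1, ..., 15] becomes
///   vector.maskedstore %base[%c0], %mask, %val
///       : memref<16xf32>, vector<16xi1>, vector<16xf32>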
6275class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6276public:
6277 using Base::Base;
6278 LogicalResult matchAndRewrite(ScatterOp op,
6279 PatternRewriter &rewriter) const override {
6280 // Fold only for memrefs: the replacement uses maskedstore, which does not
6281 // support tensor bases. Tensor cases intentionally bail out.
6282 if (!isa<MemRefType>(op.getBase().getType()))
6283 return failure();
6284
6285 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6286 return failure();
6287
6288 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6289 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6290 return success();
6291 }
6292};
6293} // namespace
6294
6295void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6296 MLIRContext *context) {
6297 results.add<ScatterFolder, FoldContiguousScatter>(context);
6298}
6299
6300FailureOr<std::optional<SmallVector<Value>>>
6301ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6303 ValueRange());
6304}
6305
6306//===----------------------------------------------------------------------===//
6307// ExpandLoadOp
6308//===----------------------------------------------------------------------===//
6309
6310LogicalResult ExpandLoadOp::verify() {
6311 VectorType maskVType = getMaskVectorType();
6312 VectorType passVType = getPassThruVectorType();
6313 VectorType resVType = getVectorType();
6314 MemRefType memType = getMemRefType();
6315
6316 if (failed(
6317 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6318 return failure();
6319 if (llvm::size(getIndices()) != memType.getRank())
6320 return emitOpError("requires ") << memType.getRank() << " indices";
6321 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6322 return emitOpError("expected result dim to match mask dim");
6323 if (resVType != passVType)
6324 return emitOpError("expected pass_thru of same type as result type");
6325 return success();
6326}
6327
6328namespace {
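/// Canonicalizes a vector.expandload with a statically all-true mask into a
/// contiguous vector.load, and one with an all-false mask into its pass_thru
/// value, since no elements are read in that case. Illustrative sketch
/// (operand names and shapes are placeholders):
///   %0 = vector.expandload %base[%i], %all_true, %pass
///       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// becomes
///   %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>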
6329class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6330public:
6331 using Base::Base;
6332 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6333 PatternRewriter &rewriter) const override {
6334    switch (getMaskFormat(expand.getMask())) {
6335    case MaskFormat::AllTrue:
6336      rewriter.replaceOpWithNewOp<vector::LoadOp>(
6337          expand, expand.getType(), expand.getBase(), expand.getIndices());
6338      return success();
6339    case MaskFormat::AllFalse:
6340      rewriter.replaceOp(expand, expand.getPassThru());
6341      return success();
6342    case MaskFormat::Unknown:
6343      return failure();
6344 }
6345 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6346 }
6347};
6348} // namespace
6349
6350void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6351 MLIRContext *context) {
6352 results.add<ExpandLoadFolder>(context);
6353}
6354
6355FailureOr<std::optional<SmallVector<Value>>>
6356ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6358 getResult());
6359}
6360
6361//===----------------------------------------------------------------------===//
6362// CompressStoreOp
6363//===----------------------------------------------------------------------===//
6364
6365LogicalResult CompressStoreOp::verify() {
6366 VectorType maskVType = getMaskVectorType();
6367 VectorType valueVType = getVectorType();
6368 MemRefType memType = getMemRefType();
6369
6370 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6371 "valueToStore")))
6372 return failure();
6373 if (llvm::size(getIndices()) != memType.getRank())
6374 return emitOpError("requires ") << memType.getRank() << " indices";
6375 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6376 return emitOpError("expected valueToStore dim to match mask dim");
6377 return success();
6378}
6379
6380namespace {
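/// Canonicalizes a vector.compressstore with a statically all-true mask into
/// a contiguous vector.store, and erases it when the mask is all-false, since
/// no elements are written in that case. Illustrative sketch (operand names
/// and shapes are placeholders):
///   vector.compressstore %base[%i], %all_true, %val
///       : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// becomes
///   vector.store %val, %base[%i] : memref<?xf32>, vector<8xf32>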
6381class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6382public:
6383 using Base::Base;
6384 LogicalResult matchAndRewrite(CompressStoreOp compress,
6385 PatternRewriter &rewriter) const override {
6386    switch (getMaskFormat(compress.getMask())) {
6387    case MaskFormat::AllTrue:
6388      rewriter.replaceOpWithNewOp<vector::StoreOp>(
6389          compress, compress.getValueToStore(), compress.getBase(),
6390          compress.getIndices());
6391      return success();
6392    case MaskFormat::AllFalse:
6393      rewriter.eraseOp(compress);
6394      return success();
6395    case MaskFormat::Unknown:
6396      return failure();
6397 }
6398 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6399 }
6400};
6401} // namespace
6402
6403void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6404 MLIRContext *context) {
6405 results.add<CompressStoreFolder>(context);
6406}
6407
6408FailureOr<std::optional<SmallVector<Value>>>
6409CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6411 ValueRange());
6412}
6413
6414//===----------------------------------------------------------------------===//
6415// ShapeCastOp
6416//===----------------------------------------------------------------------===//
6417
6418void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6419 SetIntRangeFn setResultRanges) {
6420 setResultRanges(getResult(), argRanges.front());
6421}
6422
6423std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6424 return llvm::to_vector<4>(getResultVectorType().getShape());
6425}
6426
6427LogicalResult ShapeCastOp::verify() {
6428
6429 VectorType sourceType = getSourceVectorType();
6430 VectorType resultType = getResultVectorType();
6431
6432 // Check that element type is preserved
6433 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
6434 "result")))
6435 return failure();
6436
6437 // Check that number of elements is preserved
6438 int64_t sourceNElms = sourceType.getNumElements();
6439 int64_t resultNElms = resultType.getNumElements();
6440 if (sourceNElms != resultNElms) {
6441 return emitOpError() << "has different number of elements at source ("
6442 << sourceNElms << ") and result (" << resultNElms
6443 << ")";
6444 }
6445
6446 // Check that (non-)scalability is preserved
6447 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6448 int64_t resultNScalableDims = resultType.getNumScalableDims();
6449 if (sourceNScalableDims != resultNScalableDims)
6450 return emitOpError() << "has different number of scalable dims at source ("
6451 << sourceNScalableDims << ") and result ("
6452 << resultNScalableDims << ")";
6453
6454 return success();
6455}
6456
6457/// Return true if `transpose` does not permute a pair of non-unit dims.
6458/// By `order preserving` we mean that the flattened versions of the input and
6459/// output vectors are (numerically) identical. In other words `transpose` is
6460/// effectively a shape cast.
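/// Illustrative example: for a source of type vector<1x4xf32>, the permutation
/// [1, 0] only moves a unit dimension, so
///   vector.transpose %v, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
/// is order preserving; the same permutation on vector<2x4xf32> is not.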
6461static bool isOrderPreserving(TransposeOp transpose) {
6462 ArrayRef<int64_t> permutation = transpose.getPermutation();
6463 VectorType sourceType = transpose.getSourceVectorType();
6464 ArrayRef<int64_t> inShape = sourceType.getShape();
6465 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6466 auto isNonScalableUnitDim = [&](int64_t dim) {
6467 return inShape[dim] == 1 && !inDimIsScalable[dim];
6468 };
6469 int64_t current = 0;
6470 for (auto p : permutation) {
6471 if (!isNonScalableUnitDim(p)) {
6472 if (p < current) {
6473 return false;
6474 }
6475 current = p;
6476 }
6477 }
6478 return true;
6479}
6480
6481OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6482
6483 VectorType resultType = getType();
6484
6485 // No-op shape cast.
6486 if (getSource().getType() == resultType)
6487 return getSource();
6488
6489 // shape_cast(shape_cast(x)) -> shape_cast(x)
6490 if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6491 setOperand(precedingShapeCast.getSource());
6492 return getResult();
6493 }
6494
6495 // shape_cast(transpose(x)) -> shape_cast(x)
6496 if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6497 if (isOrderPreserving(transpose)) {
6498 setOperand(transpose.getVector());
6499 return getResult();
6500 }
6501 return {};
6502 }
6503
6504 // Y = shape_cast(broadcast(X))
6505 // -> X, if X and Y have same type
6506 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6507 if (bcastOp.getSourceType() == resultType)
6508 return bcastOp.getSource();
6509 }
6510
6511 // shape_cast(constant) -> constant
6512 if (auto denseAttr =
6513 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6514 return denseAttr.reshape(getType());
6515
6516 // shape_cast(poison) -> poison
6517 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
6518 return ub::PoisonAttr::get(getContext());
6519
6520 return {};
6521}
6522
6523namespace {
6524
6525/// Helper function that computes a new vector type based on the input vector
6526/// type by removing the trailing one dims:
6527///
6528/// vector<4x1x1xi1> --> vector<4x1xi1>
6529///
6530static VectorType trimTrailingOneDims(VectorType oldType) {
6531 ArrayRef<int64_t> oldShape = oldType.getShape();
6532 ArrayRef<int64_t> newShape = oldShape;
6533
6534 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6535 ArrayRef<bool> newScalableDims = oldScalableDims;
6536
6537 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6538 newShape = newShape.drop_back(1);
6539 newScalableDims = newScalableDims.drop_back(1);
6540 }
6541
6542 // Make sure we have at least 1 dimension.
6543 // TODO: Add support for 0-D vectors.
6544 if (newShape.empty()) {
6545 newShape = oldShape.take_back();
6546 newScalableDims = oldScalableDims.take_back();
6547 }
6548
6549 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6550}
6551
6552/// Folds qualifying shape_cast(create_mask) into a new create_mask
6553///
6554/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6555/// dimension. If the input vector comes from `vector.create_mask` for which
6556/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6557/// to fold shape_cast into create_mask.
6558///
6559/// BEFORE:
6560/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6561/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6562/// AFTER:
6563/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6564class ShapeCastCreateMaskFolderTrailingOneDim final
6565 : public OpRewritePattern<ShapeCastOp> {
6566public:
6567 using Base::Base;
6568
6569 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6570 PatternRewriter &rewriter) const override {
6571 Value shapeOpSrc = shapeOp->getOperand(0);
6572 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6573 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6574 if (!createMaskOp && !constantMaskOp)
6575 return failure();
6576
6577 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6578 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6579
6580 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6581 if (newVecType != shapeOpResTy)
6582 return failure();
6583
6584 auto numDimsToDrop =
6585 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6586
6587 // No unit dims to drop
6588 if (!numDimsToDrop)
6589 return failure();
6590
6591 if (createMaskOp) {
6592 auto maskOperands = createMaskOp.getOperands();
6593 auto numMaskOperands = maskOperands.size();
6594
6595 // Check every mask dim size to see whether it can be dropped
6596 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6597 --i) {
6598 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6599 if (!constant || (constant.value() != 1))
6600 return failure();
6601 }
6602 SmallVector<Value> newMaskOperands =
6603 maskOperands.drop_back(numDimsToDrop);
6604
6605 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6606 newMaskOperands);
6607 return success();
6608 }
6609
6610 if (constantMaskOp) {
6611 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6612 auto numMaskOperands = maskDimSizes.size();
6613
6614 // Check every mask dim size to see whether it can be dropped
6615 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6616 --i) {
6617 if (maskDimSizes[i] != 1)
6618 return failure();
6619 }
6620
6621 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6622 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6623 newMaskOperands);
6624 return success();
6625 }
6626
6627 return failure();
6628 }
6629};
6630
6631/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
6632class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6633public:
6634 using Base::Base;
6635
6636 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6637 PatternRewriter &rewriter) const override {
6638 auto broadcastOp =
6639 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6640 if (!broadcastOp)
6641 return failure();
6642
6643 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6644 bool srcIsScalar = !srcVectorType;
6645
6646 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6647 // Example
6648 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6649 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6650 // to
6651 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6652 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6653 if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6654 BroadcastableToResult::Success) {
6655 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6656 shapeCastOp, dstVectorType, broadcastOp.getSource());
6657 return success();
6658 }
6659 return failure();
6660 }
6661};
6662
6663} // namespace
6664
6665void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6666 MLIRContext *context) {
6667 results
6668 .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6669 context);
6670}
6671
6672//===----------------------------------------------------------------------===//
6673// VectorBitCastOp
6674//===----------------------------------------------------------------------===//
6675
6676LogicalResult BitCastOp::verify() {
6677 auto sourceVectorType = getSourceVectorType();
6678 auto resultVectorType = getResultVectorType();
6679
6680 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6681 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6682 return emitOpError("dimension size mismatch at: ") << i;
6683 }
6684
6685 DataLayout dataLayout = DataLayout::closest(*this);
6686 auto sourceElementBits =
6687 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6688 auto resultElementBits =
6689 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6690
6691 if (sourceVectorType.getRank() == 0) {
6692 if (sourceElementBits != resultElementBits)
6693 return emitOpError("source/result bitwidth of the 0-D vector element "
6694 "types must be equal");
6695 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6696 resultElementBits * resultVectorType.getShape().back()) {
6697 return emitOpError(
6698 "source/result bitwidth of the minor 1-D vectors must be equal");
6699 }
6700
6701 return success();
6702}
6703
6704OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6705 // Nop cast.
6706 if (getSource().getType() == getResult().getType())
6707 return getSource();
6708
6709 // Canceling bitcasts.
6710 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6711 if (getResult().getType() == otherOp.getSource().getType())
6712 return otherOp.getSource();
6713
6714 setOperand(otherOp.getSource());
6715 return getResult();
6716 }
6717
6718 Attribute sourceConstant = adaptor.getSource();
6719 if (!sourceConstant)
6720 return {};
6721
6722 Type srcElemType = getSourceVectorType().getElementType();
6723 Type dstElemType = getResultVectorType().getElementType();
6724
6725 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6726 if (floatPack.isSplat()) {
6727 auto splat = floatPack.getSplatValue<FloatAttr>();
6728
6729 // Casting fp16 into fp32.
6730 if (srcElemType.isF16() && dstElemType.isF32()) {
6731 uint32_t bits = static_cast<uint32_t>(
6732 splat.getValue().bitcastToAPInt().getZExtValue());
6733 // Duplicate the 16-bit pattern.
6734 bits = (bits << 16) | (bits & 0xffff);
6735 APInt intBits(32, bits);
6736 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6737 return DenseElementsAttr::get(getResultVectorType(), floatBits);
6738 }
6739 }
6740 }
6741
6742 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6743 if (intPack.isSplat()) {
6744 auto splat = intPack.getSplatValue<IntegerAttr>();
6745
6746 if (llvm::isa<IntegerType>(dstElemType)) {
6747 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
6748 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
6749
6750 // Casting to a larger integer bit width.
6751 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6752 APInt intBits = splat.getValue().zext(dstBitWidth);
6753
6754 // Duplicate the lower width element.
6755 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6756 intBits = (intBits << srcBitWidth) | intBits;
6757 return DenseElementsAttr::get(getResultVectorType(), intBits);
6758 }
6759 }
6760 }
6761 }
6762
6763 return {};
6764}
6765
6766//===----------------------------------------------------------------------===//
6767// TypeCastOp
6768//===----------------------------------------------------------------------===//
6769
6770static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6771 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6772 SmallVector<int64_t, 8> res(memRefType.getShape());
6773 if (vectorType)
6774 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6775 return res;
6776}
6777
6778/// Build the canonical memRefType with a single vector.
6779/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6780void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6781 Value source) {
6782 result.addOperands(source);
6783 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6784 VectorType vectorType =
6785      VectorType::get(extractShape(memRefType),
6786                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
6787 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6788 memRefType.getMemorySpace()));
6789}
6790
6791LogicalResult TypeCastOp::verify() {
6792 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6793 if (!canonicalType.getLayout().isIdentity())
6794 return emitOpError("expects operand to be a memref with identity layout");
6795 if (!getResultMemRefType().getLayout().isIdentity())
6796 return emitOpError("expects result to be a memref with identity layout");
6797 if (getResultMemRefType().getMemorySpace() !=
6798 getMemRefType().getMemorySpace())
6799 return emitOpError("expects result in same memory space");
6800
6801 auto sourceType = getMemRefType();
6802 auto resultType = getResultMemRefType();
6803  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
6804      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
6805 return emitOpError(
6806 "expects result and operand with same underlying scalar type: ")
6807 << resultType;
6808 if (extractShape(sourceType) != extractShape(resultType))
6809 return emitOpError(
6810 "expects concatenated result and operand shapes to be equal: ")
6811 << resultType;
6812 return success();
6813}
6814
6815//===----------------------------------------------------------------------===//
6816// TransposeOp
6817//===----------------------------------------------------------------------===//
6818
6819void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6820 Value vector, ArrayRef<int64_t> permutation) {
6821 VectorType vt = llvm::cast<VectorType>(vector.getType());
6822 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6823 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6824 for (unsigned i = 0; i < permutation.size(); ++i) {
6825 transposedShape[i] = vt.getShape()[permutation[i]];
6826 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6827 }
6828
6829 result.addOperands(vector);
6830 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6831 transposedScalableDims));
6832 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6833 builder.getDenseI64ArrayAttr(permutation));
6834}
6835
6836OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6837 // Eliminate splat constant transpose ops.
6838 if (auto splat =
6839 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6840 return splat.reshape(getResultVectorType());
6841
6842 // Eliminate poison transpose ops.
6843 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6844 return ub::PoisonAttr::get(getContext());
6845
6846 // Eliminate identity transposes, and more generally any transposes that
6847 // preserves the shape without permuting elements.
6848 //
6849 // Examples of what to fold:
6850 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6851 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6852 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6853 //
6854 // Example of what NOT to fold:
6855 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6856 //
6857 if (getSourceVectorType() == getResultVectorType() &&
6858 isOrderPreserving(*this))
6859 return getVector();
6860
6861 return {};
6862}
6863
6864LogicalResult vector::TransposeOp::verify() {
6865 VectorType vectorType = getSourceVectorType();
6866 VectorType resultType = getResultVectorType();
6867 int64_t rank = resultType.getRank();
6868 if (vectorType.getRank() != rank)
6869 return emitOpError("vector result rank mismatch: ") << rank;
6870 // Verify transposition array.
6871 ArrayRef<int64_t> perm = getPermutation();
6872 int64_t size = perm.size();
6873 if (rank != size)
6874 return emitOpError("transposition length mismatch: ") << size;
6875 SmallVector<bool, 8> seen(rank, false);
6876 for (const auto &ta : llvm::enumerate(perm)) {
6877 if (ta.value() < 0 || ta.value() >= rank)
6878 return emitOpError("transposition index out of range: ") << ta.value();
6879 if (seen[ta.value()])
6880 return emitOpError("duplicate position index: ") << ta.value();
6881 seen[ta.value()] = true;
6882 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6883 return emitOpError("dimension size mismatch at: ") << ta.value();
6884 }
6885 return success();
6886}
6887
6888std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6889 return llvm::to_vector<4>(getResultVectorType().getShape());
6890}
6891
6892void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6893 SetIntRangeFn setResultRanges) {
6894 setResultRanges(getResult(), argRanges.front());
6895}
6896
6897namespace {
6898
6899// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
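// Illustrative sketch (values are placeholders):
//   %0 = vector.transpose %arg, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
//   %1 = vector.transpose %0, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
// composes to a single transpose of %arg with permutation [0, 1].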
6900class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6901public:
6902 using Base::Base;
6903
6904 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6905 PatternRewriter &rewriter) const override {
6906 // Composes two permutations: result[i] = permutation1[permutation2[i]].
6907 auto composePermutations = [](ArrayRef<int64_t> permutation1,
6908 ArrayRef<int64_t> permutation2) {
6909 SmallVector<int64_t, 4> result;
6910 for (auto index : permutation2)
6911 result.push_back(permutation1[index]);
6912 return result;
6913 };
6914
6915    // Return failure if the input of 'transposeOp' is not defined by another transpose.
6916 vector::TransposeOp parentTransposeOp =
6917 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6918 if (!parentTransposeOp)
6919 return failure();
6920
6921 SmallVector<int64_t, 4> permutation = composePermutations(
6922 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6923 // Replace 'transposeOp' with a new transpose operation.
6924 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6925 transposeOp, transposeOp.getResult().getType(),
6926 parentTransposeOp.getVector(), permutation);
6927 return success();
6928 }
6929};
6930
6931/// Replace transpose(splat-like(v)) with broadcast(v)
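/// Illustrative sketch (names are placeholders):
///   %0 = vector.broadcast %x : f32 to vector<4x8xf32>
///   %1 = vector.transpose %0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
/// becomes
///   %1 = vector.broadcast %x : f32 to vector<8x4xf32>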
6932class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6933public:
6934 using Base::Base;
6935
6936 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6937 PatternRewriter &rewriter) const override {
6938 Value splat = getScalarSplatSource(transposeOp.getVector());
6939 if (!splat)
6940 return failure();
6941
6942 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6943 transposeOp, transposeOp.getResultVectorType(), splat);
6944 return success();
6945 }
6946};
6947
6948/// Folds transpose(create_mask) into a new transposed create_mask.
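/// Illustrative sketch (operands are placeholders):
///   %0 = vector.create_mask %a, %b : vector<4x8xi1>
///   %1 = vector.transpose %0, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// becomes
///   %1 = vector.create_mask %b, %a : vector<8x4xi1>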
6949class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6950public:
6951 using Base::Base;
6952
6953 LogicalResult matchAndRewrite(TransposeOp transpOp,
6954 PatternRewriter &rewriter) const override {
6955 Value transposeSrc = transpOp.getVector();
6956 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6957 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6958 if (!createMaskOp && !constantMaskOp)
6959 return failure();
6960
6961 // Get the transpose permutation and apply it to the vector.create_mask or
6962 // vector.constant_mask operands.
6963 ArrayRef<int64_t> permutation = transpOp.getPermutation();
6964
6965 if (createMaskOp) {
6966 auto maskOperands = createMaskOp.getOperands();
6967 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6968 applyPermutationToVector(newOperands, permutation);
6969
6970 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6971 transpOp, transpOp.getResultVectorType(), newOperands);
6972 return success();
6973 }
6974
6975 // ConstantMaskOp case.
6976 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6977 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6978
6979 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6980 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6981 return success();
6982 }
6983};
6984
6985/// Folds transpose(shape_cast) into a new shape_cast.
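/// This only applies when the transpose is order preserving, i.e. it merely
/// moves unit dimensions. Illustrative sketch:
///   %0 = vector.shape_cast %v : vector<4xf32> to vector<4x1xf32>
///   %1 = vector.transpose %0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
/// becomes
///   %1 = vector.shape_cast %v : vector<4xf32> to vector<1x4xf32>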
6986class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6987public:
6988 using Base::Base;
6989
6990 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6991 PatternRewriter &rewriter) const override {
6992 auto shapeCastOp =
6993 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6994 if (!shapeCastOp)
6995 return failure();
6996 if (!isOrderPreserving(transposeOp))
6997 return failure();
6998
6999 VectorType resultType = transposeOp.getType();
7000
7001    // We don't need to check isValidShapeCast at this point, because it is
7002    // guaranteed that merging the transpose into the shape_cast yields a valid
7003    // shape_cast, because the transpose just inserts/removes unit dimensions.
7004
7005 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
7006 shapeCastOp.getSource());
7007 return success();
7008 }
7009};
7010
7011/// Folds transpose(from_elements(...)) into a new from_elements with permuted
7012/// operands matching the transposed shape.
7013///
7014/// Example:
7015///
7016/// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
7017/// vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to
7018/// vector<3x2xi32>
7019///
7020/// becomes ->
7021///
7022/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
7023/// vector<3x2xi32>
7024///
7025class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
7026public:
7027 using Base::Base;
7028 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7029 PatternRewriter &rewriter) const override {
7030 auto fromElementsOp =
7031 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7032 if (!fromElementsOp)
7033 return failure();
7034
7035 VectorType srcTy = fromElementsOp.getDest().getType();
7036 VectorType dstTy = transposeOp.getType();
7037
7038 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7039 int64_t rank = srcTy.getRank();
7040
7041 // Build inverse permutation to map destination indices back to source.
7042 SmallVector<int64_t> inversePerm(rank, 0);
7043 for (int64_t i = 0; i < rank; ++i)
7044 inversePerm[permutation[i]] = i;
7045
7046 ArrayRef<int64_t> srcShape = srcTy.getShape();
7047 ArrayRef<int64_t> dstShape = dstTy.getShape();
7048 SmallVector<int64_t> srcIdx(rank, 0);
7049 SmallVector<int64_t> dstIdx(rank, 0);
7050 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
7051 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
7052
7053 auto elementsOld = fromElementsOp.getElements();
7054 SmallVector<Value> elementsNew;
7055 int64_t dstNumElements = dstTy.getNumElements();
7056 elementsNew.reserve(dstNumElements);
7057
7058 // For each element in destination row-major order, pick the corresponding
7059 // source element.
7060 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7061 // Pick the destination element index.
7062 dstIdx = delinearize(linearIdx, dstStrides);
7063 // Map the destination element index to the source element index.
7064 for (int64_t j = 0; j < rank; ++j)
7065 srcIdx[j] = dstIdx[inversePerm[j]];
7066 // Linearize the source element index.
7067 int64_t srcLin = linearize(srcIdx, srcStrides);
7068 // Add the source element to the new elements.
7069 elementsNew.push_back(elementsOld[srcLin]);
7070 }
7071
7072 rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
7073 elementsNew);
7074 return success();
7075 }
7076};
7077
7078/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
7079/// 'order preserving', where 'order preserving' means the flattened
7080/// inputs and outputs of the transpose have identical (numerical) values.
7081///
7082/// Example:
7083/// ```
7084/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
7085/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
7086/// to vector<8x1xi32>
7087/// ```
7088/// can be rewritten as the equivalent
7089/// ```
7090/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
7091/// ```
7092/// The algorithm works by partitioning dimensions into groups that can be
7093/// locally permuted while preserving order, and checks that the transpose
7094/// only permutes within these groups.
7095///
7096/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
7097/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
7098/// broadcasting from 1x1x4x1x1x7.
7099/// ^^^ ^ ^^^ ^
7100/// groups: 0 1 2 3
7101/// Order preserving permutations for this example are ones that only permute
7102/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
7103class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
7104public:
7105 using Base::Base;
7106 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7107 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7108
7109 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7110 PatternRewriter &rewriter) const override {
7111
7112 vector::BroadcastOp broadcast =
7113 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
7114 if (!broadcast) {
7115 return rewriter.notifyMatchFailure(transpose,
7116 "not preceded by a broadcast");
7117 }
7118
7119 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
7120 VectorType outputType = transpose.getResultVectorType();
7121
7122 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
7123 bool inputIsScalar = !inputType;
7124 if (inputIsScalar) {
7125 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7126 broadcast.getSource());
7127 return success();
7128 }
7129
7130 ArrayRef<int64_t> permutation = transpose.getPermutation();
7131 ArrayRef<int64_t> inputShape = inputType.getShape();
7132 int64_t inputRank = inputType.getRank();
7133 int64_t outputRank = transpose.getType().getRank();
7134 int64_t deltaRank = outputRank - inputRank;
7135
7136 int low = 0;
7137 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7138 bool notOne = inputShape[inputIndex] != 1;
7139 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7140 bool groupEndFound = notOne || prevNotOne;
7141 if (groupEndFound) {
7142 int high = inputIndex + deltaRank;
7143 // Return failure if not all permutation destinations for indices in
7144 // [low, high) are in [low, high), i.e. the permutation is not local to
7145 // the group.
7146 for (int i = low; i < high; ++i) {
7147 if (permutation[i] < low || permutation[i] >= high) {
7148 return rewriter.notifyMatchFailure(
7149 transpose, "permutation not local to group");
7150 }
7151 }
7152 low = high;
7153 }
7154 }
7155
7156 // We don't need to check the final group [low, outputRank) because if it is
7157 // not locally bound, there must be a preceding group that already failed
7158 // the check (impossible to have just 1 non-locally bound group).
7159
7160 // The preceding logic also ensures that at this point, the output of the
7161 // transpose is definitely broadcastable from the input shape, assert so:
7162 assert(vector::isBroadcastableTo(inputType, outputType) ==
7163 vector::BroadcastableToResult::Success &&
7164 "not broadcastable directly to transpose output");
7165
7166 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7167 broadcast.getSource());
7168
7169 return success();
7170 }
7171};
7172
7173} // namespace
7174
7175void vector::TransposeOp::getCanonicalizationPatterns(
7176 RewritePatternSet &results, MLIRContext *context) {
7177 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7178 FoldTransposeSplat, FoldTransposeFromElements,
7179 FoldTransposeBroadcast>(context);
7180}
7181
7182//===----------------------------------------------------------------------===//
7183// ConstantMaskOp
7184//===----------------------------------------------------------------------===//
7185
7186void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7187 VectorType type, ConstantMaskKind kind) {
7188 assert(kind == ConstantMaskKind::AllTrue ||
7189 kind == ConstantMaskKind::AllFalse);
7190 build(builder, result, type,
7191 kind == ConstantMaskKind::AllTrue
7192 ? type.getShape()
7193 : SmallVector<int64_t>(type.getRank(), 0));
7194}
7195
7196LogicalResult ConstantMaskOp::verify() {
7197 auto resultType = llvm::cast<VectorType>(getResult().getType());
7198 // Check the corner case of 0-D vectors first.
7199 if (resultType.getRank() == 0) {
7200 if (getMaskDimSizes().size() != 1)
7201 return emitError("array attr must have length 1 for 0-D vectors");
7202 auto dim = getMaskDimSizes()[0];
7203 if (dim != 0 && dim != 1)
7204 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7205 return success();
7206 }
7207
7208 // Verify that array attr size matches the rank of the vector result.
7209 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7210 return emitOpError(
7211 "must specify array attr of size equal vector result rank");
7212 // Verify that each array attr element is in bounds of corresponding vector
7213 // result dimension size.
7214 auto resultShape = resultType.getShape();
7215 auto resultScalableDims = resultType.getScalableDims();
7216 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7217 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7218 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7219 return emitOpError(
7220 "array attr of size out of bounds of vector result dimension size");
7221 if (resultScalableDims[index] && maskDimSize != 0 &&
7222 maskDimSize != resultShape[index])
7223 return emitOpError(
7224 "only supports 'none set' or 'all set' scalable dimensions");
7225 }
7226 // Verify that if one mask dim size is zero, they all should be zero (because
7227 // the mask region is a conjunction of each mask dimension interval).
7228 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7229 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7230 if (anyZeros && !allZeros)
7231 return emitOpError("expected all mask dim sizes to be zeros, "
7232 "as a result of conjunction with zero mask dim");
7233 return success();
7234}
7235
7236bool ConstantMaskOp::isAllOnesMask() {
7237 auto resultType = getVectorType();
7238 // Check the corner case of 0-D vectors first.
7239 if (resultType.getRank() == 0) {
7240 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7241 return getMaskDimSizes()[0] == 1;
7242 }
7243 for (const auto [resultSize, maskDimSize] :
7244 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7245 if (maskDimSize < resultSize)
7246 return false;
7247 }
7248 return true;
7249}
7250
7251OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7252 ArrayRef<int64_t> bounds = getMaskDimSizes();
7253 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7254
7255 auto createBoolSplat = [&](bool x) {
7256    return SplatElementsAttr::get(getVectorType(),
7257                                  BoolAttr::get(getContext(), x));
7258 };
7259
7260 // Check the corner case of 0-D vectors first.
7261 if (vectorSizes.empty()) {
7262 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7263 return createBoolSplat(bounds[0] == 1);
7264 }
7265 // Fold vector.constant_mask to splat if possible.
7266 if (bounds == vectorSizes)
7267 return createBoolSplat(true);
7268 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7269 return createBoolSplat(false);
7270 return OpFoldResult();
7271}
7272
7273//===----------------------------------------------------------------------===//
7274// CreateMaskOp
7275//===----------------------------------------------------------------------===//
7276
7277void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7278 VectorType type,
7279 ArrayRef<OpFoldResult> mixedOperands) {
7280 SmallVector<Value> operands =
7281 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7282 build(builder, result, type, operands);
7283}
7284
7285LogicalResult CreateMaskOp::verify() {
7286 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7287  // Verify that an operand was specified for each result vector dimension.
7288 if (vectorType.getRank() == 0) {
7289 if (getNumOperands() != 1)
7290 return emitOpError(
7291 "must specify exactly one operand for 0-D create_mask");
7292 } else if (getNumOperands() !=
7293 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7294 return emitOpError(
7295 "must specify an operand for each result vector dimension");
7296 }
7297 return success();
7298}
7299
7300namespace {
7301
7302/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7303///
7304/// Ex 1:
7305/// %c2 = arith.constant 2 : index
7306/// %c3 = arith.constant 3 : index
7307/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7308/// Becomes:
7309/// vector.constant_mask [3, 2] : vector<4x3xi1>
7310///
7311/// Ex 2:
7312/// %c_neg_1 = arith.constant -1 : index
7313/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7314/// becomes:
7315/// vector.constant_mask [0] : vector<[8]xi1>
7316///
7317/// Ex 3:
7318/// %c8 = arith.constant 8 : index
7319/// %c16 = arith.constant 16 : index
7320/// %0 = vector.vscale
7321/// %1 = arith.muli %0, %c16 : index
7322/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7323/// becomes:
7324/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7325class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7326public:
7327 using Base::Base;
7328
7329 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7330 PatternRewriter &rewriter) const override {
7331 VectorType maskType = createMaskOp.getVectorType();
7332 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7333 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7334
7335 // Special case: Rank zero shape.
7336 constexpr std::array<int64_t, 1> rankZeroShape{1};
7337 constexpr std::array<bool, 1> rankZeroScalableDims{false};
7338 if (maskType.getRank() == 0) {
7339 maskTypeDimSizes = rankZeroShape;
7340 maskTypeDimScalableFlags = rankZeroScalableDims;
7341 }
7342
7343 // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7344 // collect the `constantDims` (for the ConstantMaskOp).
7345 SmallVector<int64_t, 4> constantDims;
7346 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7347 if (auto intSize = getConstantIntValue(dimSize)) {
7348 // Constant value.
7349 // If the mask dim is non-scalable this can be any value.
7350 // If the mask dim is scalable only zero (all-false) is supported.
7351 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7352 return failure();
7353 constantDims.push_back(*intSize);
7354 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7355 // Constant vscale multiple (e.g. 4 x vscale).
7356 // Must be all-true to fold to a ConstantMask.
7357 if (vscaleMultiplier < maskTypeDimSizes[i])
7358 return failure();
7359 constantDims.push_back(*vscaleMultiplier);
7360 } else {
7361 return failure();
7362 }
7363 }
7364
7365 // Clamp values to constant_mask bounds.
7366 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7367 value = std::clamp<int64_t>(value, 0, maskDimSize);
7368
7369 // If one of the dim sizes is zero, set all dims to zero.
7370 if (llvm::is_contained(constantDims, 0))
7371 constantDims.assign(constantDims.size(), 0);
7372
7373 // Replace 'createMaskOp' with ConstantMaskOp.
7374 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7375 constantDims);
7376 return success();
7377 }
7378};
7379
7380} // namespace
7381
7382void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7383 MLIRContext *context) {
7384 results.add<CreateMaskFolder>(context);
7385}
7386
7387//===----------------------------------------------------------------------===//
7388// MaskOp
7389//===----------------------------------------------------------------------===//
7390
7391void MaskOp::build(
7392 OpBuilder &builder, OperationState &result, Value mask,
7393 Operation *maskableOp,
7394 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7395 assert(maskRegionBuilder &&
7396 "builder callback for 'maskRegion' must be present");
7397
7398 result.addOperands(mask);
7399 OpBuilder::InsertionGuard guard(builder);
7400 Region *maskRegion = result.addRegion();
7401 builder.createBlock(maskRegion);
7402 maskRegionBuilder(builder, maskableOp);
7403}
7404
7405void MaskOp::build(
7406 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7407 Value mask, Operation *maskableOp,
7408 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7409 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
7410 maskRegionBuilder);
7411}
7412
7413void MaskOp::build(
7414 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7415 Value mask, Value passthru, Operation *maskableOp,
7416 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7417 build(builder, result, mask, maskableOp, maskRegionBuilder);
7418 if (passthru)
7419 result.addOperands(passthru);
7420 result.addTypes(resultTypes);
7421}
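// Illustrative usage sketch (hypothetical helper): wrapping an
// already-created maskable op using the builder overloads above together with
// createMaskOpRegion (defined near the end of this file). `builder` is
// expected to be positioned where the vector.mask should be created.
static MaskOp exampleWrapInMask(OpBuilder &builder, Location loc, Value mask,
                                Operation *maskableOp) {
  // The region builder moves `maskableOp` into the mask region and appends
  // the implicit vector.yield terminator.
  return MaskOp::create(builder, loc, maskableOp->getResultTypes(), mask,
                        maskableOp, createMaskOpRegion);
}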
7422
7423ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7424 // Create the op region.
7425 result.regions.reserve(1);
7426 Region &maskRegion = *result.addRegion();
7427
7428 auto &builder = parser.getBuilder();
7429
7430 // Parse all the operands.
7431 OpAsmParser::UnresolvedOperand mask;
7432 if (parser.parseOperand(mask))
7433 return failure();
7434
7435 // Optional passthru operand.
7436 OpAsmParser::UnresolvedOperand passthru;
7437 ParseResult parsePassthru = parser.parseOptionalComma();
7438 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7439 return failure();
7440
7441 // Parse op region.
7442 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7443 return failure();
7444
7445 MaskOp::ensureTerminator(maskRegion, builder, result.location);
7446
7447 // Parse the optional attribute list.
7448 if (parser.parseOptionalAttrDict(result.attributes))
7449 return failure();
7450
7451 // Parse all the types.
7452 Type maskType;
7453 if (parser.parseColonType(maskType))
7454 return failure();
7455
7456 SmallVector<Type> resultTypes;
7457 if (parser.parseOptionalArrowTypeList(resultTypes))
7458 return failure();
7459 result.types.append(resultTypes);
7460
7461 // Resolve operands.
7462 if (parser.resolveOperand(mask, maskType, result.operands))
7463 return failure();
7464
7465 if (parsePassthru.succeeded()) {
7466 if (resultTypes.empty())
7467 return parser.emitError(
7468 parser.getNameLoc(),
7469 "expects a result if passthru operand is provided");
7470
7471 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7472 return failure();
7473 }
7474
7475 return success();
7476}
7477
7478void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7479 p << " " << getMask();
7480 if (getPassthru())
7481 p << ", " << getPassthru();
7482
7483 // Print single masked operation and skip terminator.
7484 p << " { ";
7485 Block *singleBlock = &getMaskRegion().getBlocks().front();
7486 if (singleBlock && !singleBlock->getOperations().empty())
7487 p.printCustomOrGenericOp(&singleBlock->front());
7488 p << " }";
7489
7490 p.printOptionalAttrDict(getOperation()->getAttrs());
7491
7492 p << " : " << getMask().getType();
7493 if (getNumResults() > 0)
7494 p << " -> " << getResultTypes();
7495}
7496
7497void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7498 // 1. For an empty `vector.mask`, create a default terminator.
7499 if (region.empty() || region.front().empty()) {
7500 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7501 MaskOp>::ensureTerminator(region, builder, loc);
7502 return;
7503 }
7504
7505 // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7506 Block &block = region.front();
7507 if (isa<vector::YieldOp>(block.back()))
7508 return;
7509
7510 // 3. For a non-empty `vector.mask` without an explicit terminator:
7511
7512 // Create default terminator if the number of masked operations is not
7513 // one. This case will trigger a verification failure.
7514 if (block.getOperations().size() != 1) {
7515 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7516 MaskOp>::ensureTerminator(region, builder, loc);
7517 return;
7518 }
7519
7520 // Create a terminator that yields the results from the masked operation.
7521 OpBuilder opBuilder(builder.getContext());
7522 Operation *maskedOp = &block.front();
7523 opBuilder.setInsertionPointToEnd(&block);
7524 vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7525}
7526
7527LogicalResult MaskOp::verify() {
7528 // Structural checks.
7529 Block &block = getMaskRegion().getBlocks().front();
7530 if (block.getOperations().empty())
7531 return emitOpError("expects a terminator within the mask region");
7532
7533 unsigned numMaskRegionOps = block.getOperations().size();
7534 if (numMaskRegionOps > 2)
7535 return emitOpError("expects only one operation to mask");
7536
7537 // Terminator checks.
7538 auto terminator = dyn_cast<vector::YieldOp>(block.back());
7539 if (!terminator)
7540 return emitOpError("expects a terminator within the mask region");
7541
7542 if (terminator->getNumOperands() != getNumResults())
7543 return emitOpError(
7544 "expects number of results to match mask region yielded values");
7545
7546 // Empty vector.mask. Nothing else to check.
7547 if (numMaskRegionOps == 1)
7548 return success();
7549
7550 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7551 if (!maskableOp)
7552 return emitOpError("expects a MaskableOpInterface within the mask region");
7553
7554 // Result checks.
7555 if (maskableOp->getNumResults() != getNumResults())
7556 return emitOpError("expects number of results to match maskable operation "
7557 "number of results");
7558
7559 if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7560 return emitOpError("expects all the results from the MaskableOpInterface "
7561 "to match all the values returned by the terminator");
7562
7563 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7564 return emitOpError(
7565 "expects result type to match maskable operation result type");
7566
7567 if (llvm::count_if(maskableOp->getResultTypes(),
7568 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7569 return emitOpError("multiple vector results not supported");
7570
7571 // Mask checks.
7572 Type expectedMaskType = maskableOp.getExpectedMaskType();
7573 if (getMask().getType() != expectedMaskType)
7574 return emitOpError("expects a ")
7575 << expectedMaskType << " mask for the maskable operation";
7576
7577 // Passthru checks.
7578 Value passthru = getPassthru();
7579 if (passthru) {
7580 if (!maskableOp.supportsPassthru())
7581 return emitOpError(
7582 "doesn't expect a passthru argument for this maskable operation");
7583
7584 if (maskableOp->getNumResults() != 1)
7585 return emitOpError("expects result when passthru argument is provided");
7586
7587 if (passthru.getType() != maskableOp->getResultTypes()[0])
7588 return emitOpError("expects passthru type to match result type");
7589 }
7590
7591 return success();
7592}
7593
7594/// Folds empty `vector.mask` with no passthru operand and with or without
7595/// return values. For example:
7596///
7597/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7598/// vector<8xi1> -> vector<8xf32>
7599/// %1 = user_op %0 : vector<8xf32>
7600///
7601/// becomes:
7602///
7603/// %0 = user_op %a : vector<8xf32>
7604///
7605/// An empty `vector.mask` with a passthru operand is handled by the canonicalizer
7606/// instead, as that requires creating new operations.
7607
7608static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7609 SmallVectorImpl<OpFoldResult> &results) {
7610 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7611 return failure();
7612
7613 Block *block = maskOp.getMaskBlock();
7614 auto terminator = cast<vector::YieldOp>(block->front());
7615 if (terminator.getNumOperands() == 0) {
7616 // `vector.mask` has no results, just remove the `vector.mask`.
7617 return success();
7618 }
7619
7620 // `vector.mask` has results, propagate the results.
7621 llvm::append_range(results, terminator.getOperands());
7622 return success();
7623}
7624
7625LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7626 SmallVectorImpl<OpFoldResult> &results) {
7627 if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7628 return success();
7629
7630 MaskFormat maskFormat = getMaskFormat(getMask());
7631 if (maskFormat != MaskFormat::AllTrue)
7632 return failure();
7633
7634 // Move maskable operation outside of the `vector.mask` region.
7635 Operation *maskableOp = getMaskableOp();
7636 maskableOp->dropAllUses();
7637 maskableOp->moveBefore(getOperation());
7638
7639 llvm::append_range(results, maskableOp->getResults());
7640 return success();
7641}
7642
7643/// Canonicalize empty `vector.mask` operations that can't be handled in
7644/// `MaskOp::fold` as they require creating new operations.
7645///
7646/// Example 1: Empty `vector.mask` with passthru operand.
7647///
7648/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7649/// vector<8xi1> -> vector<8xf32>
7650///
7651/// becomes:
7652///
7653/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7654///
7655class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7656 using Base::Base;
7657
7658 LogicalResult matchAndRewrite(MaskOp maskOp,
7659 PatternRewriter &rewriter) const override {
7660 if (!maskOp.isEmpty())
7661 return failure();
7662
7663 if (!maskOp.hasPassthru())
7664 return failure();
7665
7666 Block *block = maskOp.getMaskBlock();
7667 auto terminator = cast<vector::YieldOp>(block->front());
7668 assert(terminator.getNumOperands() == 1 &&
7669 "expected one result when passthru is provided");
7670
7671 rewriter.replaceOpWithNewOp<arith::SelectOp>(
7672 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7673 terminator.getOperand(0), maskOp.getPassthru());
7674
7675 return success();
7676 }
7677};
7678
7679void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7680 MLIRContext *context) {
7681 results.add<CanonializeEmptyMaskOp>(context);
7682}
7683
7684// MaskingOpInterface definitions.
7685
7686/// Returns the operation masked by this 'vector.mask'.
7687Operation *MaskOp::getMaskableOp() {
7688 Block *block = getMaskBlock();
7689 if (block->getOperations().size() < 2)
7690 return nullptr;
7691
7692 return &block->front();
7693}
7694
7695/// Returns true if 'vector.mask' has a passthru value.
7696bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
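// A short, hedged sketch (hypothetical helper) of how an analysis might query
// a vector.mask op through the accessors above.
static void exampleInspectMask(Operation *op) {
  auto maskOp = dyn_cast<MaskOp>(op);
  if (!maskOp)
    return;
  // The single operation guarded by the mask region, or nullptr if the region
  // only holds the implicit vector.yield.
  if (Operation *masked = maskOp.getMaskableOp())
    (void)masked;
  // Value used for the masked-off lanes of the result, if provided.
  if (maskOp.hasPassthru())
    (void)maskOp.getPassthru();
}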
7697
7698//===----------------------------------------------------------------------===//
7699// ScanOp
7700//===----------------------------------------------------------------------===//
7701
7702LogicalResult ScanOp::verify() {
7703 VectorType srcType = getSourceType();
7704 VectorType initialType = getInitialValueType();
7705 // Check reduction dimension < rank.
7706 int64_t srcRank = srcType.getRank();
7707 int64_t reductionDim = getReductionDim();
7708 if (reductionDim >= srcRank)
7709 return emitOpError("reduction dimension ")
7710 << reductionDim << " has to be less than " << srcRank;
7711
7712 // Check that rank(initial_value) = rank(src) - 1.
7713 int64_t initialValueRank = initialType.getRank();
7714 if (initialValueRank != srcRank - 1)
7715 return emitOpError("initial value rank ")
7716 << initialValueRank << " has to be equal to " << srcRank - 1;
7717
7718 // Check shapes of initial value and src.
7719 ArrayRef<int64_t> srcShape = srcType.getShape();
7720 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7721 SmallVector<int64_t> expectedShape;
7722 for (int i = 0; i < srcRank; i++) {
7723 if (i != reductionDim)
7724 expectedShape.push_back(srcShape[i]);
7725 }
7726 if (!llvm::equal(initialValueShapes, expectedShape)) {
7727 return emitOpError("incompatible input/initial value shapes");
7728 }
7729
7730 // Verify supported reduction kind.
7731 Type eltType = getDestType().getElementType();
7732 if (!isSupportedCombiningKind(getKind(), eltType))
7733 return emitOpError("unsupported reduction type ")
7734 << eltType << " for kind '" << stringifyCombiningKind(getKind())
7735 << "'";
7736
7737 return success();
7738}
7739
7740void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
7741 RewritePatternSet &patterns, PatternBenefit benefit) {
7742 patterns
7743 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7744 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7745 StridedSliceConstantMaskFolder, TransposeFolder>(
7746 patterns.getContext(), benefit);
7747}
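// Illustrative sketch (hypothetical helper): applying just these folders to a
// piece of IR without running the full canonicalizer. Assumes the greedy
// driver entry point applyPatternsGreedily declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h, which this file does not
// include.
static LogicalResult exampleApplyVectorFolders(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  populateVectorToVectorCanonicalizationPatterns(patterns);
  return applyPatternsGreedily(root, std::move(patterns));
}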
7748
7749Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7750 CombiningKind kind, Value v1, Value acc,
7751 arith::FastMathFlagsAttr fastmath,
7752 Value mask) {
7753 Type t1 = getElementTypeOrSelf(v1.getType());
7754 Type tAcc = getElementTypeOrSelf(acc.getType());
7755 Value result;
7756
7757 switch (kind) {
7758 case CombiningKind::ADD:
7759 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7760 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7761 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7762 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7763 else
7764 llvm_unreachable("invalid value types for ADD reduction");
7765 break;
7766 case CombiningKind::AND:
7767 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7768 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7769 break;
7770 case CombiningKind::MAXNUMF:
7771 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7772 "expected float values");
7773 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7774 break;
7775 case CombiningKind::MAXIMUMF:
7776 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7777 "expected float values");
7778 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7779 break;
7780 case CombiningKind::MINNUMF:
7781 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7782 "expected float values");
7783 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7784 break;
7785 case CombiningKind::MINIMUMF:
7786 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7787 "expected float values");
7788 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7789 break;
7790 case CombiningKind::MAXSI:
7791 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7792 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7793 break;
7794 case CombiningKind::MINSI:
7795 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7796 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7797 break;
7798 case CombiningKind::MAXUI:
7799 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7800 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7801 break;
7802 case CombiningKind::MINUI:
7803 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7804 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7805 break;
7806 case CombiningKind::MUL:
7807 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7808 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7809 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7810 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7811 else
7812 llvm_unreachable("invalid value types for MUL reduction");
7813 break;
7814 case CombiningKind::OR:
7815 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7816 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7817 break;
7818 case CombiningKind::XOR:
7819 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7820 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7821 break;
7822 };
7823
7824 assert(result && "unknown CombiningKind");
7825 return selectPassthru(b, mask, result, acc);
7826}
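// Illustrative usage sketch (hypothetical helper): folding a new partial
// value into an accumulator, optionally masked; `partial` and `acc` are
// expected to be float scalars/vectors for the ADD case.
static Value exampleAccumulate(OpBuilder &builder, Location loc, Value partial,
                               Value acc, Value mask = Value()) {
  // With a non-null mask, masked-off lanes keep the accumulator value through
  // the arith.select created by selectPassthru.
  return makeArithReduction(builder, loc, CombiningKind::ADD, partial, acc,
                            /*fastmath=*/nullptr, mask);
}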
7827
7828//===----------------------------------------------------------------------===//
7829// StepOp
7830//===----------------------------------------------------------------------===//
7831
7832void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7833 SetIntRangeFn setResultRanges) {
7834 auto resultType = cast<VectorType>(getType());
7835 if (resultType.isScalable()) {
7836 return;
7837 }
7838 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7839 APInt zero(bitwidth, 0);
7840 APInt high(bitwidth, resultType.getDimSize(0) - 1);
7841 ConstantIntRanges result = {zero, high, zero, high};
7842 setResultRanges(getResult(), result);
7843}
7844
7845namespace {
7846
7847/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
7848/// constant large enough such that the result is the same at all indices.
7849///
7850/// For example, rewrite the 'greater than' comparison below,
7851///
7852/// ```mlir
7853/// %cst = arith.constant dense<7> : vector<3xindex>
7854/// %stp = vector.step : vector<3xindex>
7855/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
7856/// ```
7857///
7858/// as,
7859///
7860/// ```mlir
7861/// %out = arith.constant dense<false> : vector<3xi1>.
7862/// ```
7863///
7864/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
7865/// is false at ALL indices we fold. If the constant was 1, then
7866/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
7867/// conservatively preferring the 'compact' vector.step representation.
7868///
7869/// Note: this folder only works for the case where the constant (`%cst` above)
7870/// is the second operand of the comparison. The arith.cmpi canonicalizer will
7871/// ensure that constants are always second (on the right).
7872struct StepCompareFolder : public OpRewritePattern<StepOp> {
7873 using Base::Base;
7874
7875 LogicalResult matchAndRewrite(StepOp stepOp,
7876 PatternRewriter &rewriter) const override {
7877 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
7878
7879 for (OpOperand &use : stepOp.getResult().getUses()) {
7880 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
7881 if (!cmpiOp)
7882 continue;
7883
7884 // arith.cmpi canonicalizer makes constants final operands.
7885 const unsigned stepOperandNumber = use.getOperandNumber();
7886 if (stepOperandNumber != 0)
7887 continue;
7888
7889 // Check that operand 1 is a constant.
7890 unsigned constOperandNumber = 1;
7891 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7892 std::optional<int64_t> maybeConstValue =
7893 getConstantIntValue(otherOperand);
7894 if (!maybeConstValue.has_value())
7895 continue;
7896
7897 int64_t constValue = maybeConstValue.value();
7898 arith::CmpIPredicate pred = cmpiOp.getPredicate();
7899
7900 auto maybeSplat = [&]() -> std::optional<bool> {
7901 // Handle ult (unsigned less than) and uge (unsigned greater equal).
7902 if ((pred == arith::CmpIPredicate::ult ||
7903 pred == arith::CmpIPredicate::uge) &&
7904 stepSize <= constValue)
7905 return pred == arith::CmpIPredicate::ult;
7906
7907 // Handle ule and ugt.
7908 if ((pred == arith::CmpIPredicate::ule ||
7909 pred == arith::CmpIPredicate::ugt) &&
7910 stepSize - 1 <= constValue) {
7911 return pred == arith::CmpIPredicate::ule;
7912 }
7913
7914 // Handle eq and ne.
7915 if ((pred == arith::CmpIPredicate::eq ||
7916 pred == arith::CmpIPredicate::ne) &&
7917 stepSize <= constValue)
7918 return pred == arith::CmpIPredicate::ne;
7919
7920 return std::nullopt;
7921 }();
7922
7923 if (!maybeSplat.has_value())
7924 continue;
7925
7926 rewriter.setInsertionPointAfter(cmpiOp);
7927
7928 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7929 if (!type)
7930 continue;
7931
7932 auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
7933 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7934 type, boolAttr);
7935
7936 rewriter.replaceOp(cmpiOp, splat);
7937 return success();
7938 }
7939
7940 return failure();
7941 }
7942};
7943} // namespace
7944
7945void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
7946 MLIRContext *context) {
7947 results.add<StepCompareFolder>(context);
7948}
7949
7950//===----------------------------------------------------------------------===//
7951// Vector Masking Utilities
7952//===----------------------------------------------------------------------===//
7953
7954/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7955/// as masked operation.
7956void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7957 Operation *maskableOp) {
7958 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
7959 Block *insBlock = builder.getInsertionBlock();
7960 // Create a block and move the op to that block.
7961 insBlock->getOperations().splice(
7962 insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7963 YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
7964}
7965
7966/// Creates a vector.mask operation around a maskable operation. Returns the
7967/// vector.mask operation if a non-null mask is provided. Otherwise, returns
7968/// the maskable operation itself.
7969Operation *mlir::vector::maskOperation(OpBuilder &builder,
7970 Operation *maskableOp, Value mask,
7971 Value passthru) {
7972 if (!mask)
7973 return maskableOp;
7974 if (passthru)
7975 return MaskOp::create(builder, maskableOp->getLoc(),
7976 maskableOp->getResultTypes(), mask, passthru,
7977 maskableOp, createMaskOpRegion);
7978 return MaskOp::create(builder, maskableOp->getLoc(),
7979 maskableOp->getResultTypes(), mask, maskableOp,
 7980 createMaskOpRegion);
7981}
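// Illustrative usage sketch (hypothetical helper): trying to mask an
// operation and using the appropriate results afterwards.
static void exampleMaskAndUse(OpBuilder &builder, Operation *maskableOp,
                              Value mask) {
  Operation *maybeMasked = maskOperation(builder, maskableOp, mask);
  // With a null mask the original op is returned unchanged; otherwise
  // `maybeMasked` is the new vector.mask op and its results should be used in
  // place of the results of `maskableOp`.
  (void)maybeMasked;
}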
7982
7983/// Creates a vector select operation that picks values from `newValue` or
7984/// `passthru` for each result vector lane based on `mask`. This utility is used
7985/// to propagate the pass-thru value of vector.mask or for cases where only the
7986/// pass-thru value propagation is needed. VP intrinsics do not support
7987/// pass-thru values and every masked-off lane is set to poison. LLVM backends are
7988/// usually able to match op + select patterns and fold them into native
7989/// target instructions.
7990Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7991 Value newValue, Value passthru) {
7992 if (!mask)
7993 return newValue;
7994
7995 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
7996 mask, newValue, passthru);
7997}
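// Illustrative usage sketch (hypothetical helper): propagating a pass-thru
// value after lowering a masked op to an unmasked form; all vector values are
// expected to have matching types.
static Value exampleGuardWithPassthru(OpBuilder &builder, Value mask,
                                      Value loweredResult, Value passthru) {
  // A null mask simply returns `loweredResult` without creating arith.select.
  return selectPassthru(builder, mask, loweredResult, passthru);
}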
7998
7999//===----------------------------------------------------------------------===//
8000// TableGen'd op method definitions
8001//===----------------------------------------------------------------------===//
8002
8003#define GET_ATTRDEF_CLASSES
8004#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8005
8006#define GET_OP_CLASSES
8007#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition VectorOps.cpp:72
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
Definition VectorOps.cpp:62
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
#define mul(a, b)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
bool empty()
Definition Block.h:158
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
iterator begin()
Definition Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:834
void setOperand(unsigned idx, Value value)
Definition Operation.h:351
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isF32() const
Definition Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:578
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition VectorOps.h:64
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition VectorOps.h:72
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const