MLIR 23.0.0git
VectorOps.cpp
Go to the documentation of this file.
1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
15
27#include "mlir/IR/AffineExpr.h"
28#include "mlir/IR/AffineMap.h"
29#include "mlir/IR/Builders.h"
33#include "mlir/IR/IRMapping.h"
37#include "mlir/IR/ValueRange.h"
40#include "mlir/Support/LLVM.h"
42#include "llvm/ADT/ArrayRef.h"
43#include "llvm/ADT/STLExtras.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/Casting.h"
49
50#include <cassert>
51#include <cstdint>
52#include <numeric>
53
54#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
55// Pull in all enum type and utility function definitions.
56#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
57
58using namespace mlir;
59using namespace mlir::vector;
60
61/// Helper enum to classify mask value.
// NOTE(review): this listing lost several physical lines during extraction
// (the MaskFormat enumerators, the classifier's signature, and some early
// `return MaskFormat::...` statements); all surviving code is kept verbatim.
62enum class MaskFormat {
66};
67
68/// Helper method to classify a mask value. Currently, the method
69/// looks "under the hood" of a constant value with dense attributes
70/// and a constant mask operation (since the client may be called at
71/// various stages during progressive lowering).
73 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
74 // Inspect constant dense values. We count up for bits that
75 // are set, count down for bits that are cleared, and bail
76 // when a mix is detected.
77 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
78 int64_t val = 0;
79 for (bool b : denseElts.getValues<bool>())
80 if (b && val >= 0)
81 val++;
82 else if (!b && val <= 0)
83 val--;
84 else
86 if (val > 0)
88 if (val < 0)
90 }
91 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
92 // Inspect constant mask index. If the index exceeds the
93 // dimension size, all bits are set. If the index is zero
94 // or less, no bits are set.
95 ArrayRef<int64_t> masks = m.getMaskDimSizes();
96 auto shape = m.getType().getShape();
97 bool allTrue = true;
98 bool allFalse = true;
99 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
100 if (maskIdx < dimSize)
101 allTrue = false;
102 if (maskIdx > 0)
103 allFalse = false;
104 }
105 if (allTrue)
106 return MaskFormat::AllTrue;
107 if (allFalse)
109 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
110 // Finds all-false create_masks. An all-true create_mask requires all
111 // dims to be constants, so that'll be folded to a constant_mask, then
112 // detected in the constant_mask case.
113 auto maskOperands = m.getOperands();
114 for (Value operand : maskOperands) {
115 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
116 int64_t dimSize =
117 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
// NOTE(review): the `return` taken when dimSize <= 0 is missing below.
118 if (dimSize <= 0)
120 }
121 }
122 return MaskFormat::Unknown;
123 }
124 return MaskFormat::Unknown;
125}
126
127/// Default callback to build a region with a 'vector.yield' terminator with no
128/// arguments.
// NOTE(review): the builder-callback signature line is missing from this
// listing; the body only creates the terminator.
130 vector::YieldOp::create(builder, loc);
131}
132
133// Helper for verifying combining kinds in contractions and reductions.
134static bool isSupportedCombiningKind(CombiningKind combiningKind,
135 Type elementType) {
136 switch (combiningKind) {
137 case CombiningKind::ADD:
138 case CombiningKind::MUL:
139 return elementType.isIntOrIndexOrFloat();
140 case CombiningKind::MINUI:
141 case CombiningKind::MINSI:
142 case CombiningKind::MAXUI:
143 case CombiningKind::MAXSI:
144 case CombiningKind::AND:
145 case CombiningKind::OR:
146 case CombiningKind::XOR:
147 return elementType.isIntOrIndex();
148 case CombiningKind::MINNUMF:
149 case CombiningKind::MAXNUMF:
150 case CombiningKind::MINIMUMF:
151 case CombiningKind::MAXIMUMF:
152 return llvm::isa<FloatType>(elementType);
153 }
154 return false;
155}
156
157/// Returns the effective rank of the vector to read/write for Xfer Ops
158///
159/// When the element type of the shaped type is _a scalar_, this will simply
160/// return the rank of the vector ( the result for xfer_read or the value to
161/// store for xfer_write).
162///
163/// When the element type of the base shaped type is _a vector_, returns the
164/// difference between the original vector type and the element type of the
165/// shaped type.
166///
167/// EXAMPLE 1 (element type is _a scalar_):
168/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
169/// - shapedType.getElementType() = f32 (rank 0)
170/// - vectorType.getRank() = 2
171/// - Result = 2 - 0 = 2
172///
173/// EXAMPLE 2 (element type is _a vector_):
174/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
175/// - shapedType.getElementType() = vector<20xf32> (rank 1)
176/// - vectorType.getRank() = 1
177/// - Result = 1 - 1 = 0
178///
179/// This is used to determine the number of minor dimensions for identity maps
180/// in vector transfer Ops.
181static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
182 VectorType vectorType) {
183 unsigned elementVectorRank = 0;
184 VectorType elementVectorType =
185 llvm::dyn_cast<VectorType>(shapedType.getElementType());
186 if (elementVectorType)
187 elementVectorRank += elementVectorType.getRank();
188 return vectorType.getRank() - elementVectorRank;
189}
190
// NOTE(review): the enclosing function's signature line and the call that
// builds the minor identity map (original line 200) are missing from this
// listing; the surviving body is kept verbatim.
192 VectorType vectorType) {
193 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
194 // TODO: replace once we have 0-d vectors.
195 if (shapedType.getRank() == 0 &&
196 vectorType.getShape() == ArrayRef<int64_t>{1})
197 return AffineMap::get(
198 /*numDims=*/0, /*numSymbols=*/0,
199 getAffineConstantExpr(0, shapedType.getContext()));
201 shapedType.getRank(),
202 getEffectiveVectorRankForXferOp(shapedType, vectorType),
203 shapedType.getContext());
204}
205
206/// Check if `write` is of a constant splat and the masked `read` is padded with
207/// the same splat value -- meaning it could be the same value as the initial
208/// constant splat.
209static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
210 vector::TransferReadOp read) {
211 auto readMask = read.getMask();
212 auto writeMask = write.getMask();
213 // Check if the masks are consistent. The splat value could be the same if the
214 // read is masked (and padded with the splat value), and the write is unmasked
215 // or has the same mask. Note this does not allow the case where the write is
216 // masked and the read is unmasked, as then the read could be of more elements
217 // than the write (which may not be the same value).
218 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
219 if (!couldBeSameSplat)
220 return false;
221 // Check for constant splat (as the source of the write).
222 DenseElementsAttr splatAttr;
223 if (!matchPattern(write.getVector(),
224 m_Constant<DenseElementsAttr>(&splatAttr)) ||
225 !splatAttr.isSplat()) {
226 return false;
227 }
228 // The padding of the read and the constant splat value must be the same.
229 Attribute padAttr;
230 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
231 return false;
232 return padAttr == splatAttr.getSplatValue<Attribute>();
233}
234
// Returns true when the read-after-write pair accesses the same elements the
// same way (indices, vector type, permutation map) with in-bounds writes.
// NOTE(review): the final operand of the trailing `||` (original line 242,
// the mask-consistency check) is missing from this listing.
235bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
236 vector::TransferReadOp read) {
237 return !defWrite.hasOutOfBoundsDim() &&
238 defWrite.getIndices() == read.getIndices() &&
239 defWrite.getVectorType() == read.getVectorType() &&
240 defWrite.getPermutationMap() == read.getPermutationMap() &&
241 ((!defWrite.getMask() && !read.getMask()) ||
243}
244
245bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
246 vector::TransferWriteOp priorWrite) {
247 return priorWrite.getIndices() == write.getIndices() &&
248 priorWrite.getMask() == write.getMask() &&
249 priorWrite.getVectorType() == write.getVectorType() &&
250 priorWrite.getPermutationMap() == write.getPermutationMap();
251}
252
// NOTE(review): this is the body of mlir::vector::isDisjointTransferIndices
// (it is invoked by isDisjointTransferSet below); its signature line and the
// right-hand sides of the `delta`/`testEqual`/`computeDelta` initializers
// (value-bounds queries) are missing from this listing. Surviving code is
// kept verbatim.
254 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
255 bool testDynamicValueUsingBounds) {
256 // For simplicity only look at transfer of same type.
257 if (transferA.getVectorType() != transferB.getVectorType())
258 return false;
259 unsigned rankOffset = transferA.getLeadingShapedRank();
260 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
261 Value indexA = transferA.getIndices()[i];
262 Value indexB = transferB.getIndices()[i];
263 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
264 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
265
266 if (i < rankOffset) {
267 // For leading dimensions, if we can prove that index are different we
268 // know we are accessing disjoint slices.
269 if (cstIndexA.has_value() && cstIndexB.has_value()) {
270 if (*cstIndexA != *cstIndexB)
271 return true;
272 continue;
273 }
274 if (testDynamicValueUsingBounds) {
275 // First try to see if we can fully compose and simplify the affine
276 // expression as a fast track.
277 FailureOr<uint64_t> delta =
279 if (succeeded(delta) && *delta != 0)
280 return true;
281
282 FailureOr<bool> testEqual =
284 if (succeeded(testEqual) && !testEqual.value())
285 return true;
286 }
287 } else {
288 // For this dimension, we slice a part of the memref we need to make sure
289 // the intervals accessed don't overlap.
290 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
291 if (cstIndexA.has_value() && cstIndexB.has_value()) {
292 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
293 if (distance >= vectorDim)
294 return true;
295 continue;
296 }
297 if (testDynamicValueUsingBounds) {
298 // First try to see if we can fully compose and simplify the affine
299 // expression as a fast track.
300 FailureOr<int64_t> delta =
302 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
303 return true;
304
305 FailureOr<int64_t> computeDelta =
307 if (succeeded(computeDelta)) {
308 if (std::abs(computeDelta.value()) >= vectorDim)
309 return true;
310 }
311 }
312 }
313 }
314 return false;
315}
316
317bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
318 VectorTransferOpInterface transferB,
319 bool testDynamicValueUsingBounds) {
320 if (transferA.getBase() != transferB.getBase())
321 return false;
322 return isDisjointTransferIndices(transferA, transferB,
323 testDynamicValueUsingBounds);
324}
325
326// Helper to iterate over n-D vector slice elements. Calculate the next
327// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
328// Modifies the `position` in place. Returns a failure when `position` becomes
329// the end position.
// NOTE(review): the `ArrayRef<int64_t> shape` parameter line is missing from
// this listing.
330static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
332 ArrayRef<int64_t> offsets) {
333 for (auto [posInDim, dimSize, offsetInDim] :
334 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
335 ++posInDim;
336 if (posInDim < dimSize + offsetInDim)
337 return success();
338
339 // Carry the overflow to the next loop iteration.
340 posInDim = offsetInDim;
341 }
342
343 return failure();
344}
345
346/// Returns the integer numbers in `values`. `values` are expected to be
347/// constant operations.
// NOTE(review): the signature and the `ints` result-vector declaration are
// missing from this listing.
350 llvm::transform(values, std::back_inserter(ints), [](Value value) {
351 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
352 assert(constOp && "Unexpected non-constant index");
353 return constOp.value();
354 });
355 return ints;
356}
357
358/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
359/// be constant operations.
// NOTE(review): the signature and the `ints` result-vector declaration are
// missing from this listing.
362 llvm::transform(
363 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
364 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
365 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
366 });
367 return ints;
368}
369
370/// Convert `foldResults` into Values. Integer attributes are converted to
371/// constant op.
// NOTE(review): the signature line and the constant-op creation call inside
// the lambda (original line 378) are missing from this listing.
373 ArrayRef<OpFoldResult> foldResults) {
374 SmallVector<Value> values;
375 llvm::transform(foldResults, std::back_inserter(values),
376 [&](OpFoldResult foldResult) {
377 if (auto attr = dyn_cast<Attribute>(foldResult))
379 builder, loc, cast<IntegerAttr>(attr).getInt())
380 .getResult();
381
382 return cast<Value>(foldResult);
383 });
384 return values;
385}
386
387std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
388 if (value.getDefiningOp<vector::VectorScaleOp>())
389 return 1;
390 auto mul = value.getDefiningOp<arith::MulIOp>();
391 if (!mul)
392 return {};
393 auto lhs = mul.getLhs();
394 auto rhs = mul.getRhs();
395 if (lhs.getDefiningOp<vector::VectorScaleOp>())
396 return getConstantIntValue(rhs);
397 if (rhs.getDefiningOp<vector::VectorScaleOp>())
398 return getConstantIntValue(lhs);
399 return {};
400}
401
402/// Converts numeric attributes to the expected type. Supports
403/// integer-to-integer and float-to-integer conversions. Returns the original
404/// attribute if no conversion is needed or supported.
405static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
406 // Integer-to-integer conversion
407 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
408 if (auto intType = dyn_cast<IntegerType>(expectedType)) {
409 if (intAttr.getType() != expectedType)
410 return IntegerAttr::get(expectedType, intAttr.getInt());
411 }
412 return attr;
413 }
414
415 // Float-to-integer bitcast (preserves bit representation)
416 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
417 auto intType = dyn_cast<IntegerType>(expectedType);
418 if (!intType)
419 return attr;
420
421 APFloat floatVal = floatAttr.getValue();
422 APInt intVal = floatVal.bitcastToAPInt();
423 return IntegerAttr::get(expectedType, intVal);
424 }
425
426 return attr;
427}
428
429//===----------------------------------------------------------------------===//
430// CombiningKindAttr
431//===----------------------------------------------------------------------===//
432
// Storage uniquer for the CombiningKind bitmask attribute.
// NOTE(review): the struct declaration line (BitmaskEnumStorage), its `value`
// member, the `construct` signature, and the placement-new initializer are
// missing from this listing; surviving code is kept verbatim.
433namespace mlir {
434namespace vector {
435namespace detail {
437 using KeyTy = uint64_t;
438
440
441 bool operator==(const KeyTy &key) const { return value == key; }
442
444 const KeyTy &key) {
445 return new (allocator.allocate<BitmaskEnumStorage>())
447 }
448
450};
451} // namespace detail
452} // namespace vector
453} // namespace mlir
454
455//===----------------------------------------------------------------------===//
456// VectorDialect
457//===----------------------------------------------------------------------===//
458
459namespace {
460/// This class defines the interface for handling inlining with vector dialect
461/// operations.
462struct VectorInlinerInterface : public DialectInlinerInterface {
463 using DialectInlinerInterface::DialectInlinerInterface;
464
465 /// All vector dialect ops can be inlined.
466 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
467 return true;
468 }
469};
470} // namespace
471
472void VectorDialect::initialize() {
473 addAttributes<
474#define GET_ATTRDEF_LIST
475#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
476 >();
477
478 addOperations<
479#define GET_OP_LIST
480#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
481 >();
482
483 addInterfaces<VectorInlinerInterface>();
484
485 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
486 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
487 YieldOp>();
488 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
489 TransferWriteOp>();
490 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
491 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
492 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
493}
494
495/// Materialize a single constant operation from a given attribute value with
496/// the desired resultant type.
497Operation *VectorDialect::materializeConstant(OpBuilder &builder,
498 Attribute value, Type type,
499 Location loc) {
500 if (matchPattern(value, ub::m_Poison()))
501 return value.getDialect().materializeConstant(builder, value, type, loc);
502
503 return arith::ConstantOp::materialize(builder, value, type, loc);
504}
505
// NOTE(review): the signatures of these two small helpers are missing from
// this listing. The bodies return a 64-bit integer type and an i64 array
// attribute respectively — presumably the vector-subscript type/attribute
// helpers; confirm against upstream.
507 return builder.getIntegerType(64);
508}
509
511 ArrayRef<int64_t> values) {
512 return builder.getI64ArrayAttr(values);
513}
514
515//===----------------------------------------------------------------------===//
516// MultiDimReductionOp
517//===----------------------------------------------------------------------===//
518
519void vector::MultiDimReductionOp::build(OpBuilder &builder,
520 OperationState &result, Value source,
521 Value acc, ArrayRef<bool> reductionMask,
522 CombiningKind kind) {
523 SmallVector<int64_t> reductionDims;
524 for (const auto &en : llvm::enumerate(reductionMask))
525 if (en.value())
526 reductionDims.push_back(en.index());
527 build(builder, result, kind, source, acc, reductionDims);
528}
529
530OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
531 // Single parallel dim, this is a noop.
532 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
533 return getSource();
534 return {};
535}
536
537std::optional<SmallVector<int64_t, 4>>
538MultiDimReductionOp::getShapeForUnroll() {
539 return llvm::to_vector<4>(getSourceVectorType().getShape());
540}
541
542LogicalResult MultiDimReductionOp::verify() {
543 SmallVector<int64_t> targetShape;
544 SmallVector<bool> scalableDims;
545 Type inferredReturnType;
546 auto sourceScalableDims = getSourceVectorType().getScalableDims();
547 for (auto [dimIdx, dimSize] :
548 llvm::enumerate(getSourceVectorType().getShape()))
549 if (!llvm::any_of(getReductionDims(),
550 [dimIdx = dimIdx](int64_t reductionDimIdx) {
551 return reductionDimIdx == static_cast<int64_t>(dimIdx);
552 })) {
553 targetShape.push_back(dimSize);
554 scalableDims.push_back(sourceScalableDims[dimIdx]);
555 }
556 // TODO: update to also allow 0-d vectors when available.
557 if (targetShape.empty())
558 inferredReturnType = getSourceVectorType().getElementType();
559 else
560 inferredReturnType = VectorType::get(
561 targetShape, getSourceVectorType().getElementType(), scalableDims);
562 if (getType() != inferredReturnType)
563 return emitOpError() << "destination type " << getType()
564 << " is incompatible with source type "
565 << getSourceVectorType();
566
567 return success();
568}
569
570/// Returns the mask type expected by this operation.
571Type MultiDimReductionOp::getExpectedMaskType() {
572 auto vecType = getSourceVectorType();
573 return VectorType::get(vecType.getShape(),
574 IntegerType::get(vecType.getContext(), /*width=*/1),
575 vecType.getScalableDims());
576}
577
578namespace {
579// Only unit dimensions that are being reduced are folded. If the dimension is
580// unit, but not reduced, it is not folded, thereby keeping the output type the
581// same. If not all dimensions which are reduced are of unit dimension, this
582// transformation does nothing. This is just a generalization of
583// ElideSingleElementReduction for ReduceOp.
584struct ElideUnitDimsInMultiDimReduction
585 : public OpRewritePattern<MultiDimReductionOp> {
586 using Base::Base;
587
588 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
589 PatternRewriter &rewriter) const override {
// Bail out unless every reduced dimension has static size 1.
590 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
591 for (const auto &dim : enumerate(shape)) {
592 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
593 return failure();
594 }
595
596 // Vector mask setup.
// When the reduction is masked, the enclosing vector.mask op (rootOp) is the
// op that gets replaced, and new ops are inserted before it.
597 OpBuilder::InsertionGuard guard(rewriter);
598 Operation *rootOp;
599 Value mask;
600 if (reductionOp.isMasked()) {
601 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
602 rootOp = reductionOp.getMaskingOp();
603 mask = reductionOp.getMaskingOp().getMask();
604 } else {
605 rootOp = reductionOp;
606 }
607
608 Location loc = reductionOp.getLoc();
609 Value acc = reductionOp.getAcc();
610 Value cast;
// Vector result: drop the unit reduced dims with a shape_cast (the mask, if
// any, is shape_cast to match the destination shape).
611 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
612 if (mask) {
613 VectorType newMaskType =
614 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
615 dstVecType.getScalableDims());
616 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
617 }
618 cast = vector::ShapeCastOp::create(
619 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
620 } else {
621 // This means we are reducing all the dimensions, and all reduction
622 // dimensions are of size 1. So a simple extraction would do.
623 if (mask)
624 mask = vector::ExtractOp::create(rewriter, loc, mask);
625 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
626 }
627
// Combine with the accumulator (applying the mask) and replace the root op.
628 Value result =
629 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
630 cast, /*fastmath=*/nullptr, mask);
631 rewriter.replaceOp(rootOp, result);
632 return success();
633 }
634};
635} // namespace
636
637void MultiDimReductionOp::getCanonicalizationPatterns(
638 RewritePatternSet &results, MLIRContext *context) {
639 results.add<ElideUnitDimsInMultiDimReduction>(context);
640}
641
642//===----------------------------------------------------------------------===//
643// ReductionOp
644//===----------------------------------------------------------------------===//
645
646void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
647 CombiningKind kind, Value vector,
648 arith::FastMathFlags fastMathFlags) {
649 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
650}
651
652void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
653 CombiningKind kind, Value vector, Value acc,
654 arith::FastMathFlags fastMathFlags) {
655 build(builder, result,
656 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
657 acc, fastMathFlags);
658}
659
660LogicalResult ReductionOp::verify() {
661 // Verify for 0-D and 1-D vector.
662 int64_t rank = getSourceVectorType().getRank();
663 if (rank > 1)
664 return emitOpError("unsupported reduction rank: ") << rank;
665
666 // Verify supported reduction kind.
667 Type eltType = getDest().getType();
668 if (!isSupportedCombiningKind(getKind(), eltType))
669 return emitOpError("unsupported reduction type '")
670 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
671 << "'";
672
673 return success();
674}
675
676// MaskableOpInterface methods.
677
678/// Returns the mask type expected by this operation.
679Type ReductionOp::getExpectedMaskType() {
680 auto vecType = getSourceVectorType();
681 return VectorType::get(vecType.getShape(),
682 IntegerType::get(vecType.getContext(), /*width=*/1),
683 vecType.getScalableDims());
684}
685
// Maps an arith::AtomicRMWKind to the corresponding vector.reduction op and
// creates it; returns nullptr (after emitting an optional error at `loc`) for
// unsupported kinds.
// NOTE(review): the function's signature line is missing from this listing.
// Also note the created ops use `vector.getLoc()` rather than the `loc`
// parameter, which is only used for the error — confirm intent upstream.
687 OpBuilder &builder, Location loc,
688 Value vector) {
689 switch (op) {
690 case arith::AtomicRMWKind::addf:
691 case arith::AtomicRMWKind::addi:
692 return vector::ReductionOp::create(builder, vector.getLoc(),
693 CombiningKind::ADD, vector);
694 case arith::AtomicRMWKind::mulf:
695 case arith::AtomicRMWKind::muli:
696 return vector::ReductionOp::create(builder, vector.getLoc(),
697 CombiningKind::MUL, vector);
698 case arith::AtomicRMWKind::minimumf:
699 return vector::ReductionOp::create(builder, vector.getLoc(),
700 CombiningKind::MINIMUMF, vector);
701 case arith::AtomicRMWKind::mins:
702 return vector::ReductionOp::create(builder, vector.getLoc(),
703 CombiningKind::MINSI, vector);
704 case arith::AtomicRMWKind::minu:
705 return vector::ReductionOp::create(builder, vector.getLoc(),
706 CombiningKind::MINUI, vector);
707 case arith::AtomicRMWKind::maximumf:
708 return vector::ReductionOp::create(builder, vector.getLoc(),
709 CombiningKind::MAXIMUMF, vector);
710 case arith::AtomicRMWKind::maxs:
711 return vector::ReductionOp::create(builder, vector.getLoc(),
712 CombiningKind::MAXSI, vector);
713 case arith::AtomicRMWKind::maxu:
714 return vector::ReductionOp::create(builder, vector.getLoc(),
715 CombiningKind::MAXUI, vector);
716 case arith::AtomicRMWKind::andi:
717 return vector::ReductionOp::create(builder, vector.getLoc(),
718 CombiningKind::AND, vector);
719 case arith::AtomicRMWKind::ori:
720 return vector::ReductionOp::create(builder, vector.getLoc(),
721 CombiningKind::OR, vector);
722 case arith::AtomicRMWKind::minnumf:
723 return vector::ReductionOp::create(builder, vector.getLoc(),
724 CombiningKind::MINNUMF, vector);
725 case arith::AtomicRMWKind::maxnumf:
726 return vector::ReductionOp::create(builder, vector.getLoc(),
727 CombiningKind::MAXNUMF, vector);
728 case arith::AtomicRMWKind::xori:
729 return vector::ReductionOp::create(builder, vector.getLoc(),
730 CombiningKind::XOR, vector);
731 default:
732 (void)emitOptionalError(loc, "Reduction operation type not supported");
733 break;
734 }
735 return nullptr;
736}
737
738std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
739 return llvm::to_vector<4>(getSourceVectorType().getShape());
740}
741
742namespace {
// Replaces a reduction over a 0-D or single-element 1-D vector with a plain
// extract (combined with the accumulator when present).
743struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
744 using Base::Base;
745
746 LogicalResult matchAndRewrite(ReductionOp reductionOp,
747 PatternRewriter &rewriter) const override {
748 // Vector mask setup.
// When the reduction is masked, the enclosing vector.mask op (rootOp) is the
// op that gets replaced, and new ops are inserted before it.
749 OpBuilder::InsertionGuard guard(rewriter);
750 auto maskableOp =
751 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
752 Operation *rootOp;
753 Value mask;
754 if (maskableOp.isMasked()) {
755 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
756 rootOp = maskableOp.getMaskingOp();
757 mask = maskableOp.getMaskingOp().getMask();
758 } else {
759 rootOp = reductionOp;
760 }
761
// Only applies to 0-D vectors or 1-D vectors with a single element.
762 auto vectorType = reductionOp.getSourceVectorType();
763 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
764 return failure();
765
766 Location loc = reductionOp.getLoc();
767 if (mask)
768 mask = ExtractOp::create(rewriter, loc, mask);
769 Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
770
// Fold the extracted scalar into the accumulator, if one is present.
771 if (Value acc = reductionOp.getAcc())
772 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
773 result, acc,
774 reductionOp.getFastmathAttr(), mask);
775
776 rewriter.replaceOp(rootOp, result);
777 return success();
778 }
779};
780} // namespace
781
782void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
783 MLIRContext *context) {
784 results.add<ElideSingleElementReduction>(context);
785}
786
787//===----------------------------------------------------------------------===//
788// ContractionOp
789//===----------------------------------------------------------------------===//
790
// NOTE(review): each of the three build overloads below is missing its
// `Value lhs, Value rhs, Value acc,` parameter line in this listing; the
// surviving code is kept verbatim.
// Builds a contraction from affine-expression indexing lists; infers the
// indexing maps and wraps the iterator types into attributes.
791void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
793 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
794 ArrayRef<IteratorType> iteratorTypes) {
795 result.addOperands({lhs, rhs, acc});
796 result.addTypes(acc.getType());
797 result.addAttribute(
798 getIndexingMapsAttrName(result.name),
799 builder.getAffineMapArrayAttr(
800 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
801 result.addAttribute(
802 getIteratorTypesAttrName(result.name),
803 builder.getArrayAttr(llvm::map_to_vector(
804 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
805 return IteratorTypeAttr::get(builder.getContext(), t);
806 })));
807}
808
// Delegates to the full overload using the default combining kind.
809void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
811 ArrayAttr indexingMaps,
812 ArrayAttr iteratorTypes) {
813 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
814 ContractionOp::getDefaultKind());
815}
816
// Builds a contraction from pre-built attribute arrays plus a combining kind.
817void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
819 ArrayAttr indexingMaps,
820 ArrayAttr iteratorTypes, CombiningKind kind) {
821 result.addOperands({lhs, rhs, acc});
822 result.addTypes(acc.getType());
823 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
824 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
825 result.addAttribute(getKindAttrName(result.name),
826 CombiningKindAttr::get(builder.getContext(), kind));
827}
828
// NOTE(review): the local declarations for lhsInfo/rhsInfo/accInfo/masksInfo
// and `types` (original lines 830-834) are missing from this listing; the
// surviving code is kept verbatim.
829ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
835 Type resultType;
836 auto loc = parser.getCurrentLocation();
837 DictionaryAttr dictAttr;
838 // TODO: Unify linalg op attribute parsing.
// Parse `{attrs} lhs, rhs, acc [, masks] attr-dict : lhsTy, rhsTy into resTy`
// and resolve all operands against the parsed types.
839 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
840 parser.parseComma() || parser.parseOperand(rhsInfo) ||
841 parser.parseComma() || parser.parseOperand(accInfo) ||
842 parser.parseTrailingOperandList(masksInfo) ||
843 parser.parseOptionalAttrDict(result.attributes) ||
844 parser.parseColonTypeList(types) ||
845 parser.parseKeywordType("into", resultType) ||
846 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
847 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
848 parser.resolveOperand(accInfo, resultType, result.operands) ||
849 parser.addTypeToList(resultType, result.types))
850 return failure();
851 result.attributes.append(dictAttr.getValue().begin(),
852 dictAttr.getValue().end());
853
854 // Convert array of string into an array of IteratyType enums. This is needed,
855 // because tests still use the old format when 'iterator_types' attribute is
856 // represented as an array of strings.
857 // TODO: Remove this conversion once tests are fixed.
858 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
859 result.attributes.get(getIteratorTypesAttrName(result.name)));
860 if (!iteratorTypes) {
861 return parser.emitError(loc)
862 << "expected " << getIteratorTypesAttrName(result.name)
863 << " array attribute";
864 }
865
866 SmallVector<Attribute> iteratorTypeAttrs;
867
868 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
869 auto maybeIteratorType = symbolizeIteratorType(s);
870 if (!maybeIteratorType.has_value())
871 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
872
873 iteratorTypeAttrs.push_back(
874 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
875 }
876 result.attributes.set(getIteratorTypesAttrName(result.name),
877 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
878
// Default the combining kind when the attribute dictionary omitted it.
879 if (!result.attributes.get(getKindAttrName(result.name))) {
880 result.addAttribute(
881 getKindAttrName(result.name),
882 CombiningKindAttr::get(result.getContext(),
883 ContractionOp::getDefaultKind()));
884 }
885 if (masksInfo.empty())
886 return success();
887 if (masksInfo.size() != 2)
888 return parser.emitError(parser.getNameLoc(),
889 "expected zero or exactly 2 vector mask operands");
// The optional masks are i1 vectors shaped like lhs and rhs respectively.
890 auto lhsType = llvm::cast<VectorType>(types[0]);
891 auto rhsType = llvm::cast<VectorType>(types[1]);
892 auto maskElementType = parser.getBuilder().getI1Type();
893 std::array<VectorType, 2> maskTypes = {
894 VectorType::Builder(lhsType).setElementType(maskElementType),
895 VectorType::Builder(rhsType).setElementType(maskElementType)};
896 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
897 return failure();
898 return success();
899}
900
// NOTE(review): the declaration of the `attrs` vector (original line 906) is
// missing from this listing; the surviving code is kept verbatim.
901void ContractionOp::print(OpAsmPrinter &p) {
902 // TODO: Unify printing code with linalg ops.
903 auto attrNames = getTraitAttrNames();
904 llvm::StringSet<> traitAttrsSet;
905 traitAttrsSet.insert_range(attrNames);
// Collect the trait attributes into the leading dictionary, converting
// iterator types back to their legacy string form.
907 for (auto attr : (*this)->getAttrs()) {
908 if (attr.getName() == getIteratorTypesAttrName()) {
909 auto iteratorTypes =
910 llvm::cast<ArrayAttr>(attr.getValue())
911 .getAsValueRange<IteratorTypeAttr, IteratorType>();
912 // Convert IteratorType enums into the string representation. This is
913 // needed, because tests still use the old format when 'iterator_types'
914 // attribute is represented as an array of strings.
915 // TODO: Remove this conversion once tests are fixed.
916 SmallVector<Attribute> iteratorTypeNames =
917 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
918 return StringAttr::get(getContext(), stringifyIteratorType(t));
919 });
920
921 attrs.emplace_back(getIteratorTypesAttrName(),
922 ArrayAttr::get(getContext(), iteratorTypeNames));
923 } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
924 attrs.push_back(attr);
925 }
926
// Print `{dict} lhs, rhs, acc`, the remaining (non-trait) attributes, and the
// trailing type signature.
927 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
928 p << " " << dictAttr << " " << getLhs() << ", ";
929 p << getRhs() << ", " << getAcc();
930
931 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
932 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
933 << getResultType();
934}
935
936static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
937 const std::vector<std::pair<int64_t, int64_t>> &map) {
938 for (auto &dimPair : map) {
939 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
940 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
941 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
942 return false;
943 }
944 return true;
945}
946
/// Verifies that the accumulator and result types of `op` agree with the
/// shape implied by the lhs/rhs vector types together with the contracting
/// and batch dimension maps. Emits an op error and returns failure on
/// mismatch.
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  // Batch dims were already contributed by the lhs, so skip them here.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    // Record the constant extent of every iteration dimension as seen on the
    // lhs/rhs operands (first writer wins; sizes must agree by construction).
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape =
        llvm::map_to_vector<4>(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        });
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
1035
1036LogicalResult ContractionOp::verify() {
1037 VectorType lhsType = getLhsType();
1038 VectorType rhsType = getRhsType();
1039 Type accType = getAccType();
1040 Type resType = getResultType();
1041
1042 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1043 if (!lhsType.getElementType().isSignlessInteger())
1044 return emitOpError("only supports signless integer types");
1045 }
1046
1047 // Verify that an indexing map was specified for each vector operand.
1048 if (getIndexingMapsArray().size() != 3)
1049 return emitOpError("expected an indexing map for each vector operand");
1050
1051 // Verify that each index map has 'numIterators' inputs, no symbols, and
1052 // that the number of map outputs equals the rank of its associated
1053 // vector operand.
1054 unsigned numIterators = getIteratorTypes().getValue().size();
1055 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1056 auto index = it.index();
1057 auto map = it.value();
1058 if (map.getNumSymbols() != 0)
1059 return emitOpError("expected indexing map ")
1060 << index << " to have no symbols";
1061 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1062 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1063 // Verify that the map has the right number of inputs, outputs, and indices.
1064 // This also correctly accounts for (..) -> () for rank-0 results.
1065 if (map.getNumDims() != numIterators)
1066 return emitOpError("expected indexing map ")
1067 << index << " to have " << numIterators << " number of inputs";
1068 if (map.getNumResults() != rank)
1069 return emitOpError("expected indexing map ")
1070 << index << " to have " << rank << " number of outputs";
1071 if (!map.isProjectedPermutation())
1072 return emitOpError("expected indexing map ")
1073 << index << " to be a projected permutation of its inputs";
1074 }
1075
1076 auto contractingDimMap = getContractingDimMap();
1077 auto batchDimMap = getBatchDimMap();
1078
1079 // Verify at least one contracting dimension pair was specified.
1080 if (contractingDimMap.empty())
1081 return emitOpError("expected at least one contracting dimension pair");
1082
1083 // Verify contracting dimension map was properly constructed.
1084 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1085 return emitOpError("invalid contracting dimension map");
1086
1087 // Verify batch dimension map was properly constructed.
1088 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1089 return emitOpError("invalid batch dimension map");
1090
1091 // Verify 'accType' and 'resType' shape.
1092 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1093 contractingDimMap, batchDimMap)))
1094 return failure();
1095
1096 if (!getKindAttr()) {
1097 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
1098 "'vector.kind<add>')");
1099 }
1100
1101 // Verify supported combining kind.
1102 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1103 auto elementType = vectorType ? vectorType.getElementType() : resType;
1104 if (!isSupportedCombiningKind(getKind(), elementType))
1105 return emitOpError("unsupported contraction type");
1106
1107 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1108 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1109}
1110
1111// MaskableOpInterface methods.
1112
1113/// Returns the mask type expected by this operation. Mostly used for
1114/// verification purposes. It requires the operation to be vectorized."
1115Type ContractionOp::getExpectedMaskType() {
1116 auto indexingMaps = this->getIndexingMapsArray();
1117 AffineMap lhsIdxMap = indexingMaps[0];
1118 AffineMap rhsIdxMap = indexingMaps[1];
1119 VectorType lhsType = this->getLhsType();
1120 VectorType rhsType = this->getRhsType();
1121
1122 unsigned numVecDims = lhsIdxMap.getNumDims();
1123 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1124 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1125
1126 // Using the information in the indexing maps, extract the size of each
1127 // dimension in the vector.contract operation from the two input operands.
1128 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1129 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1130 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1131 lhsType.getScalableDims()[dimIdx];
1132 }
1133 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1134 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1135 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1136 rhsType.getScalableDims()[dimIdx];
1137 }
1138
1139 assert(ShapedType::isStaticShape(maskShape) &&
1140 "Mask shape couldn't be computed");
1141
1142 return VectorType::get(maskShape,
1143 IntegerType::get(lhsType.getContext(), /*width=*/1),
1144 maskShapeScalableDims);
1145}
1146
1147SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1148 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1149 getIteratorTypesAttrName(), getKindAttrName()};
1150}
1151
1153 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1154 if (targetExpr == map.getResult(i))
1155 return i;
1156 return -1;
1157}
1158
1159static std::vector<std::pair<int64_t, int64_t>>
1160getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1161 IteratorType targetIteratorType, MLIRContext *context) {
1162 std::vector<std::pair<int64_t, int64_t>> dimMap;
1163 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1164 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1165 if (iteratorType != targetIteratorType)
1166 continue;
1167 // Search lhs/rhs map results for 'targetExpr'.
1168 auto targetExpr = getAffineDimExpr(it.index(), context);
1169 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1170 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1171 if (lhsDim >= 0 && rhsDim >= 0)
1172 dimMap.emplace_back(lhsDim, rhsDim);
1173 }
1174 return dimMap;
1175}
1176
1177void ContractionOp::getIterationBounds(
1178 SmallVectorImpl<int64_t> &iterationBounds) {
1179 auto lhsShape = getLhsType().getShape();
1180 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1181 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1182 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1183 // Search lhs/rhs map results for 'targetExpr'.
1184 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1185 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1186 if (iteratorType == IteratorType::reduction) {
1187 // Get reduction dim size from lhs shape (same size in rhsShape).
1188 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1189 assert(lhsDimIndex >= 0);
1190 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1191 continue;
1192 }
1193 // Get parallel dimension size from result shape.
1194 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1195 assert(resDimIndex >= 0);
1196 assert(resVectorType != nullptr);
1197 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1198 }
1199}
1200
1201void ContractionOp::getIterationIndexMap(
1202 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1203 unsigned numMaps = getIndexingMapsArray().size();
1204 iterationIndexMap.resize(numMaps);
1205 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1206 auto index = it.index();
1207 auto map = it.value();
1208 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1209 auto dim = cast<AffineDimExpr>(map.getResult(i));
1210 iterationIndexMap[index][dim.getPosition()] = i;
1211 }
1212 }
1213}
1214
1215std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1216 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1217 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1218 getContext());
1219}
1220
1221std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1222 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1223 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1224 getContext());
1225}
1226
// The unroll shape of a contraction is its full iteration space, i.e. the
// per-iterator bounds computed by getIterationBounds().
std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  getIterationBounds(shape);
  return shape;
}
1232
1233/// Return a fused vector::ContractionOp which represents a patterns such as:
1234///
1235/// ```mlir
1236/// %c0 = vector.constant 0: ...
1237/// %c = vector.contract %a, %b, %c0: ...
1238/// %e = add %c, %d: ...
1239/// ```
1240///
1241/// by:
1242///
1243/// ```mlir
1244/// %e = vector.contract %a, %b, %d: ...
1245/// ```
1246///
1247/// Return null if the canonicalization does not apply.
1248// TODO: This should be a folding of Add into Contract in core but while they
1249// live in different dialects, it is not possible without unnatural
1250// dependencies.
1251template <typename AddOpType>
1252struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1253 using OpRewritePattern<AddOpType>::OpRewritePattern;
1254
1255 LogicalResult matchAndRewrite(AddOpType addOp,
1256 PatternRewriter &rewriter) const override {
1257 auto canonicalize = [&](Value maybeContraction,
1258 Value otherOperand) -> vector::ContractionOp {
1259 vector::ContractionOp contractionOp =
1260 dyn_cast_or_null<vector::ContractionOp>(
1261 maybeContraction.getDefiningOp());
1262 if (!contractionOp)
1263 return vector::ContractionOp();
1264 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1265 contractionOp.getAcc().getDefiningOp())) {
1266 if (maybeZero.getValue() ==
1267 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1268 IRMapping bvm;
1269 bvm.map(contractionOp.getAcc(), otherOperand);
1270 auto newContraction =
1271 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1272 rewriter.replaceOp(addOp, newContraction.getResult());
1273 return newContraction;
1274 }
1275 }
1276 return vector::ContractionOp();
1277 };
1278
1279 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1280 vector::ContractionOp contract = canonicalize(a, b);
1281 contract = contract ? contract : canonicalize(b, a);
1282 return contract ? success() : failure();
1283 }
1284};
1285
// Hook that populates `results` with the canonicalization patterns registered
// for vector.contract.
void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
}
1291
1292// Returns `true` if `index` is either within [0, maxIndex) or equal to
1293// `poisonValue`.
1295 int64_t maxIndex) {
1296 return index == poisonValue || (index >= 0 && index < maxIndex);
1297}
1298
1299//===----------------------------------------------------------------------===//
1300// ExtractOp
1301//===----------------------------------------------------------------------===//
1302
1303void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1304 SetIntRangeFn setResultRanges) {
1305 setResultRanges(getResult(), argRanges.front());
1306}
1307
1308void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1309 Value source) {
1310 auto vectorTy = cast<VectorType>(source.getType());
1311 build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1312}
1313
1314void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1315 Value source, int64_t position) {
1316 build(builder, result, source, ArrayRef<int64_t>{position});
1317}
1318
1319void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1320 Value source, OpFoldResult position) {
1321 build(builder, result, source, ArrayRef<OpFoldResult>{position});
1322}
1323
1324void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1325 Value source, ArrayRef<int64_t> position) {
1326 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1327 builder.getDenseI64ArrayAttr(position));
1328}
1329
1330void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1331 Value source, ArrayRef<OpFoldResult> position) {
1332 SmallVector<int64_t> staticPos;
1333 SmallVector<Value> dynamicPos;
1334 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1335 build(builder, result, source, dynamicPos,
1336 builder.getDenseI64ArrayAttr(staticPos));
1337}
1338
1339LogicalResult
1340ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1341 ExtractOp::Adaptor adaptor,
1342 SmallVectorImpl<Type> &inferredReturnTypes) {
1343 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1344 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1345 vectorType.getRank()) {
1346 inferredReturnTypes.push_back(vectorType.getElementType());
1347 } else {
1348 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1349 vectorType.getRank());
1350 inferredReturnTypes.push_back(VectorType::get(
1351 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1352 vectorType.getScalableDims().drop_front(n)));
1353 }
1354 return success();
1355}
1356
1357bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1358 // Allow extracting 1-element vectors instead of scalars.
1359 auto isCompatible = [](TypeRange l, TypeRange r) {
1360 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1361 return vectorType && vectorType.getShape().equals({1}) &&
1362 vectorType.getElementType() == r.front();
1363 };
1364 if (l.size() == 1 && r.size() == 1 &&
1365 (isCompatible(l, r) || isCompatible(r, l)))
1366 return true;
1367 return l == r;
1368}
1369
LogicalResult vector::ExtractOp::verify() {
  // 0-D vector results are rejected: a fully-indexed extract must produce a
  // scalar, not a 0-d vector.
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the result type");

  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  // The number of indices may not exceed the source vector's rank.
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  // Each static index must be in bounds for its dimension or be the poison
  // sentinel (-1); dynamic indices are not checked here.
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
          constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}
1402
1403template <typename IntType>
1405 return llvm::map_to_vector<4>(
1406 arrayAttr.getAsRange<IntegerAttr>(),
1407 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1408}
1409
1410/// Fold the result of chains of ExtractOp in place by simply concatenating the
1411/// positions.
1412static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1413 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1414 return failure();
1415
1416 // TODO: Canonicalization for dynamic position not implemented yet.
1417 if (extractOp.hasDynamicPosition())
1418 return failure();
1419
1420 SmallVector<int64_t> globalPosition;
1421 ExtractOp currentOp = extractOp;
1422 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1423 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1424 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1425 currentOp = nextOp;
1426 // TODO: Canonicalization for dynamic position not implemented yet.
1427 if (currentOp.hasDynamicPosition())
1428 return failure();
1429 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1430 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1431 }
1432 extractOp.setOperand(0, currentOp.getSource());
1433 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1434 OpBuilder b(extractOp.getContext());
1435 std::reverse(globalPosition.begin(), globalPosition.end());
1436 extractOp.setStaticPosition(globalPosition);
1437 return success();
1438}
1439
1440namespace {
1441/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1442/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1443/// Compose TransposeOp permutations as we walk back.
1444/// This helper class keeps an updated extraction position `extractPosition`
1445/// with extra trailing sentinels.
1446/// The sentinels encode the internal transposition status of the result vector.
1447/// As we iterate, extractPosition is permuted and updated.
1448class ExtractFromInsertTransposeChainState {
1449public:
1450 ExtractFromInsertTransposeChainState(ExtractOp e);
1451
1452 /// Iterate over producing insert and transpose ops until we find a fold.
1453 Value fold();
1454
1455private:
1456 /// Return true if the vector at position `a` is contained within the vector
1457 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1458 /// is a prefix of `b`.
1459 template <typename ContainerA, typename ContainerB>
1460 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1461 return a.size() <= b.size() &&
1462 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1463 }
1464
1465 /// Return true if the vector at position `a` intersects the vector at
1466 /// position `b`. Under insert/extract semantics, this is the same as equality
1467 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1468 /// Comparison is on the common prefix (i.e. zip).
1469 template <typename ContainerA, typename ContainerB>
1470 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1471 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1472 if (elemA < 0 || elemB < 0)
1473 continue;
1474 if (elemA != elemB)
1475 return false;
1476 }
1477 return true;
1478 }
1479
1480 /// Folding is only possible in the absence of an internal permutation in the
1481 /// result vector.
1482 bool canFold() {
1483 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1484 }
1485
1486 // Helper to get the next defining op of interest.
1487 void updateStateForNextIteration(Value v) {
1488 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1489 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1490 };
1491
1492 // Case 1. If we hit a transpose, just compose the map and iterate.
1493 // Invariant: insert + transpose do not change rank, we can always compose.
1494 LogicalResult handleTransposeOp();
1495
1496 // Case 2: the insert position matches extractPosition exactly, early return.
1497 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1498
1499 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1500 /// portion of the source of the insert.
1501 /// Example:
1502 /// ```
1503 /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
1504 /// // extractPosition == [1, 2, 3]
1505 /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1506 /// // can fold to vector.extract %source[0, 3]
1507 /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1508 /// ```
1509 /// To traverse through %source, we need to set the leading dims to 0 and
1510 /// drop the extra leading dims.
1511 /// This method updates the internal state.
1512 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1513
1514 /// Try to fold in place to extract(source, extractPosition) and return the
1515 /// folded result. Return null if folding is not possible (e.g. due to an
1516 /// internal transposition in the result).
1517 Value tryToFoldExtractOpInPlace(Value source);
1518
1519 ExtractOp extractOp;
1520 int64_t vectorRank;
1521 int64_t extractedRank;
1522
1523 InsertOp nextInsertOp;
1524 TransposeOp nextTransposeOp;
1525
1526 /// Sentinel values that encode the internal permutation status of the result.
1527 /// They are set to (-1, ... , -k) at the beginning and appended to
1528 /// `extractPosition`.
1529 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1530 /// ensure that there is no internal transposition.
1531 /// Internal transposition cannot be accounted for with a folding pattern.
1532 // TODO: We could relax the internal transposition with an extra transposition
1533 // operation in a future canonicalizer.
1534 SmallVector<int64_t> sentinels;
1535 SmallVector<int64_t> extractPosition;
1536};
1537} // namespace
1538
1539ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1540 ExtractOp e)
1541 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1542 extractedRank(extractOp.getNumIndices()) {
1543 assert(vectorRank >= extractedRank && "Extracted position overflow");
1544 sentinels.reserve(vectorRank - extractedRank);
1545 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1546 sentinels.push_back(-(i + 1));
1547 extractPosition.assign(extractOp.getStaticPosition().begin(),
1548 extractOp.getStaticPosition().end());
1549 llvm::append_range(extractPosition, sentinels);
1550}
1551
// Case 1. If we hit a transpose, just compose the map and iterate.
// Invariant: insert + transpose do not change rank, we can always compose.
LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  // Nothing to do if the current producer is not a transpose.
  if (!nextTransposeOp)
    return failure();
      nextTransposeOp.getPermutation(), extractOp.getContext()));
  return success();
}
1566
1567// Case 2: the insert position matches extractPosition exactly, early return.
1568LogicalResult
1569ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1570 Value &res) {
1571 // TODO: Canonicalization for dynamic position not implemented yet.
1572 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1573 return failure();
1574
1575 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1576 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1577 return failure();
1578 // Case 2.a. early-exit fold.
1579 res = nextInsertOp.getValueToStore();
1580 // Case 2.b. if internal transposition is present, canFold will be false.
1581 return success(canFold());
1582}
1583
1584/// Case 3: if inserted position is a prefix of extractPosition,
1585/// extract a portion of the source of the insertion.
1586/// This method updates the internal state.
1587LogicalResult
1588ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1589 // TODO: Canonicalization for dynamic position not implemented yet.
1590 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1591 return failure();
1592
1593 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1594 if (!isContainedWithin(insertedPos, extractPosition))
1595 return failure();
1596 // Set leading dims to zero.
1597 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1598 // Drop extra leading dims.
1599 extractPosition.erase(extractPosition.begin(),
1600 extractPosition.begin() + insertedPos.size());
1601 extractedRank = extractPosition.size() - sentinels.size();
1602 // Case 3.a. early-exit fold (break and delegate to post-while path).
1603 res = nextInsertOp.getValueToStore();
1604 // Case 3.b. if internal transposition is present, canFold will be false.
1605 return success();
1606}
1607
1608/// Try to fold in place to extract(source, extractPosition) and return the
1609/// folded result. Return null if folding is not possible (e.g. due to an
1610/// internal transposition in the result).
1611Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1612 Value source) {
1613 // TODO: Canonicalization for dynamic position not implemented yet.
1614 if (extractOp.hasDynamicPosition())
1615 return Value();
1616
1617 // If we can't fold (either internal transposition, or nothing to fold), bail.
1618 bool nothingToFold = (source == extractOp.getSource());
1619 if (nothingToFold || !canFold())
1620 return Value();
1621
1622 // Otherwise, fold by updating the op inplace and return its result.
1623 OpBuilder b(extractOp.getContext());
1624 extractOp.setStaticPosition(
1625 ArrayRef(extractPosition).take_front(extractedRank));
1626 extractOp.getSourceMutable().assign(source);
1627 return extractOp.getResult();
1628}
1629
/// Iterate over producing insert and transpose ops until we find a fold.
/// Each iteration steps strictly up the def chain (to a transpose's vector or
/// an insert's destination), so the walk terminates.
Value ExtractFromInsertTransposeChainState::fold() {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  Value valueToExtractFrom = extractOp.getSource();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the position match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: if the inserted position is a prefix of extractPosition, we can
    // just extract a portion of the source of the insert.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
    // values. This is a more difficult case and we bail.
    ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
      return Value();

    // Case 5: No intersection, we forward the extract to insertOp.dest().
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
1671
1672/// Returns true if the operation has a 0-D vector type operand or result.
1674 auto hasZeroDimVectorType = [](Type type) -> bool {
1675 auto vecType = dyn_cast<VectorType>(type);
1676 return vecType && vecType.getRank() == 0;
1677 };
1678
1679 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1680 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1681}
1682
1683/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1684/// considered to be 'broadcastlike'.
1685static bool isBroadcastLike(Operation *op) {
1686 if (isa<BroadcastOp>(op))
1687 return true;
1688
1689 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1690 if (!shapeCast)
1691 return false;
1692
1693 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1694 // Checking that the destination shape has a prefix of 1s is not sufficient,
1695 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1696 // is that the source shape is a suffix of the destination shape.
1697 VectorType srcType = shapeCast.getSourceVectorType();
1698 ArrayRef<int64_t> srcShape = srcType.getShape();
1699 uint64_t srcRank = srcType.getRank();
1700 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1701 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1702}
1703
/// Fold extract(broadcast(X)) to either extract(X) or just X.
///
/// Example:
///
///        broadcast            extract [1][2]
/// (3, 4) --------> (2, 3, 4) ----------------> (4)
///
/// becomes
///        extract [1]
/// (3, 4) ------------------------------------> (4)
///
/// The variable names used in this implementation correspond to the above
/// shapes as,
///
/// - (3, 4) is `input` shape.
/// - (2, 3, 4) is `broadcast` shape.
/// - (4) is `extract` shape.
///
/// This folding is possible when the suffix of `input` shape is the same as
/// `extract` shape.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {

  // Only fold a broadcast (or a shape_cast that merely prepends unit dims).
  Operation *defOp = extractOp.getSource().getDefiningOp();
  if (!defOp || !isBroadcastLike(defOp))
    return Value();

  Value input = defOp->getOperand(0);

  // Replace extract(broadcast(X)) with X when the types already match.
  if (extractOp.getType() == input.getType())
    return input;

  // Get required types and ranks in the chain
  //   input -> broadcast -> extract
  // (scalars are treated as rank-0).
  auto inputType = llvm::dyn_cast<VectorType>(input.getType());
  auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
  unsigned inputRank = inputType ? inputType.getRank() : 0;
  unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
  unsigned extractRank = extractType ? extractType.getRank() : 0;

  // Cannot do without the broadcast if overall the rank increases.
  if (extractRank > inputRank)
    return Value();

  // The above condition guarantees that input is a vector.
  assert(inputType && "input must be a vector type because of previous checks");
  ArrayRef<int64_t> inputShape = inputType.getShape();

  // In the case where there is a broadcast dimension in the suffix, it is not
  // possible to replace extract(broadcast(X)) with extract(X). Example:
  //
  //     broadcast       extract
  // (1) --------> (3,4) ------> (4)
  if (extractType &&
      extractType.getShape() != inputShape.take_back(extractRank))
    return Value();

  // Replace extract(broadcast(X)) with extract(X).
  // First, determine the new extraction position.
  unsigned deltaOverall = inputRank - extractRank;
  unsigned deltaBroadcast = broadcastRank - inputRank;
  SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
  SmallVector<OpFoldResult> newPositions(deltaOverall);
  IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
  for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
    // A size-1 input dim may have been stretched by the broadcast, so index 0
    // is the only valid position into it; other dims keep their original
    // position, shifted past the dims the broadcast prepended.
    newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
  }
  // Rewrite the extract in place: read from `input` at the new position.
  auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
  extractOp->setOperands(
      llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
  extractOp.setStaticPosition(staticPos);
  return extractOp.getResult();
}
1779
1780/// Fold extractOp coming from ShuffleOp.
1781///
1782/// Example:
1783///
1784/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1785/// : vector<8xf32>, vector<8xf32>
1786/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1787/// ->
1788/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1789///
1790static Value foldExtractFromShuffle(ExtractOp extractOp) {
1791 // Dynamic positions are not folded as the resulting code would be more
1792 // complex than the input code.
1793 if (extractOp.hasDynamicPosition())
1794 return Value();
1795
1796 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1797 if (!shuffleOp)
1798 return Value();
1799
1800 // TODO: 0-D or multi-dimensional vectors not supported yet.
1801 if (shuffleOp.getResultVectorType().getRank() != 1)
1802 return Value();
1803
1804 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1805 auto shuffleMask = shuffleOp.getMask();
1806 int64_t extractIdx = extractOp.getStaticPosition()[0];
1807 int64_t shuffleIdx = shuffleMask[extractIdx];
1808
1809 // Find the shuffled vector to extract from based on the shuffle index.
1810 if (shuffleIdx < inputVecSize) {
1811 extractOp.setOperand(0, shuffleOp.getV1());
1812 extractOp.setStaticPosition({shuffleIdx});
1813 } else {
1814 extractOp.setOperand(0, shuffleOp.getV2());
1815 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1816 }
1817
1818 return extractOp.getResult();
1819}
1820
1821// Fold extractOp with source coming from ShapeCast op.
1822static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1823 // TODO: Canonicalization for dynamic position not implemented yet.
1824 if (extractOp.hasDynamicPosition())
1825 return Value();
1826
1827 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1828 if (!shapeCastOp)
1829 return Value();
1830
1831 // Get the nth dimension size starting from lowest dimension.
1832 auto getDimReverse = [](VectorType type, int64_t n) {
1833 return type.getShape().take_back(n + 1).front();
1834 };
1835 int64_t destinationRank =
1836 llvm::isa<VectorType>(extractOp.getType())
1837 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1838 : 0;
1839 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1840 return Value();
1841 if (destinationRank > 0) {
1842 auto destinationType =
1843 llvm::cast<VectorType>(extractOp.getResult().getType());
1844 for (int64_t i = 0; i < destinationRank; i++) {
1845 // The lowest dimension of the destination must match the lowest
1846 // dimension of the shapecast op source.
1847 // TODO: This case could be support in a canonicalization pattern.
1848 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1849 getDimReverse(destinationType, i))
1850 return Value();
1851 }
1852 }
1853 // Extract the strides associated with the extract op vector source. Then use
1854 // this to calculate a linearized position for the extract.
1855 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1856 std::reverse(extractedPos.begin(), extractedPos.end());
1858 int64_t stride = 1;
1859 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1860 strides.push_back(stride);
1861 stride *=
1862 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1863 }
1864
1865 int64_t position = linearize(extractedPos, strides);
1866 // Then extract the strides associated to the shapeCast op vector source and
1867 // delinearize the position using those strides.
1868 SmallVector<int64_t, 4> newStrides;
1869 int64_t numDimension =
1870 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1871 stride = 1;
1872 for (int64_t i = 0; i < numDimension; i++) {
1873 newStrides.push_back(stride);
1874 stride *=
1875 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1876 }
1877 std::reverse(newStrides.begin(), newStrides.end());
1878 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1879 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1880 OpBuilder b(extractOp.getContext());
1881 extractOp.setStaticPosition(newPosition);
1882 extractOp.setOperand(0, shapeCastOp.getSource());
1883 return extractOp.getResult();
1884}
1885
1886/// Fold an ExtractOp from ExtractStridedSliceOp.
1887static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1888 // TODO: Canonicalization for dynamic position not implemented yet.
1889 if (extractOp.hasDynamicPosition())
1890 return Value();
1891
1892 auto extractStridedSliceOp =
1893 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1894 if (!extractStridedSliceOp)
1895 return Value();
1896
1897 // 0-D vectors not supported.
1898 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1899 if (hasZeroDimVectors(extractStridedSliceOp))
1900 return Value();
1901
1902 // Return if 'extractStridedSliceOp' has non-unit strides.
1903 if (extractStridedSliceOp.hasNonUnitStrides())
1904 return Value();
1905
1906 // Trim offsets for dimensions fully extracted.
1907 auto sliceOffsets =
1908 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1909 while (!sliceOffsets.empty()) {
1910 size_t lastOffset = sliceOffsets.size() - 1;
1911 if (sliceOffsets.back() != 0 ||
1912 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1913 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1914 break;
1915 sliceOffsets.pop_back();
1916 }
1917 unsigned destinationRank = 0;
1918 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1919 destinationRank = vecType.getRank();
1920 // The dimensions of the result need to be untouched by the
1921 // extractStridedSlice op.
1922 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1923 sliceOffsets.size())
1924 return Value();
1925
1926 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1927 assert(extractedPos.size() >= sliceOffsets.size());
1928 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1929 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1930 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1931
1932 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1933 OpBuilder b(extractOp.getContext());
1934 extractOp.setStaticPosition(extractedPos);
1935 return extractOp.getResult();
1936}
1937
1938/// Fold extract_op fed from a chain of insertStridedSlice ops.
1939static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1940 // TODO: Canonicalization for dynamic position not implemented yet.
1941 if (extractOp.hasDynamicPosition())
1942 return Value();
1943
1944 int64_t destinationRank =
1945 llvm::isa<VectorType>(extractOp.getType())
1946 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1947 : 0;
1948 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1949 if (!insertOp)
1950 return Value();
1951
1952 // 0-D vectors not supported.
1953 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1954 if (hasZeroDimVectors(insertOp))
1955 return Value();
1956
1957 while (insertOp) {
1958 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1959 insertOp.getSourceVectorType().getRank();
1960 if (destinationRank > insertOp.getSourceVectorType().getRank())
1961 return Value();
1962 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1963 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1964
1965 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1966 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1967 }))
1968 return Value();
1969 bool disjoint = false;
1970 SmallVector<int64_t, 4> offsetDiffs;
1971 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1972 int64_t start = insertOffsets[dim];
1973 int64_t size =
1974 (dim < insertRankDiff)
1975 ? 1
1976 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1977 int64_t end = start + size;
1978 int64_t offset = extractOffsets[dim];
1979 // Check if the start of the extract offset is in the interval inserted.
1980 if (start <= offset && offset < end) {
1981 if (dim >= insertRankDiff)
1982 offsetDiffs.push_back(offset - start);
1983 continue;
1984 }
1985 disjoint = true;
1986 break;
1987 }
1988 // The extract element chunk overlap with the vector inserted.
1989 if (!disjoint) {
1990 // If any of the inner dimensions are only partially inserted we have a
1991 // partial overlap.
1992 int64_t srcRankDiff =
1993 insertOp.getSourceVectorType().getRank() - destinationRank;
1994 for (int64_t i = 0; i < destinationRank; i++) {
1995 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1996 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1997 insertRankDiff))
1998 return Value();
1999 }
2000 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2001 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2002 OpBuilder b(extractOp.getContext());
2003 extractOp.setStaticPosition(offsetDiffs);
2004 return extractOp.getResult();
2005 }
2006 // If the chunk extracted is disjoint from the chunk inserted, keep
2007 // looking in the insert chain.
2008 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2009 }
2010 return Value();
2011}
2012
2013/// Try to fold the extraction of a scalar from a vector defined by
2014/// vector.from_elements. E.g.:
2015///
2016/// %0 = vector.from_elements %a, %b : vector<2xf32>
2017/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2018/// ==> fold to %a
2019static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2020 // Dynamic extractions cannot be folded.
2021 if (extractOp.hasDynamicPosition())
2022 return {};
2023
2024 // Look for extract(from_elements).
2025 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2026 if (!fromElementsOp)
2027 return {};
2028
2029 // Scalable vectors are not supported.
2030 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2031 if (vecType.isScalable())
2032 return {};
2033
2034 // Only extractions of scalars are supported.
2035 int64_t rank = vecType.getRank();
2036 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2037 if (extractOp.getType() != vecType.getElementType())
2038 return {};
2039 assert(static_cast<int64_t>(indices.size()) == rank &&
2040 "unexpected number of indices");
2041
2042 // Compute flattened/linearized index and fold to operand.
2043 int flatIndex = 0;
2044 int stride = 1;
2045 for (int i = rank - 1; i >= 0; --i) {
2046 flatIndex += indices[i] * stride;
2047 stride *= vecType.getDimSize(i);
2048 }
2049 return fromElementsOp.getElements()[flatIndex];
2050}
2051
2052/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2053/// then fold it.
2054template <typename OpType, typename AdaptorType>
2055static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2056 SmallVectorImpl<Value> &operands) {
2057 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2058 OperandRange dynamicPosition = op.getDynamicPosition();
2059 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2061 if constexpr (std::is_same_v<OpType, ExtractOp>)
2062 vectorShape = op.getSourceVectorType().getShape();
2063 else
2064 vectorShape = op.getDestVectorType().getShape();
2065
2066 // If the dynamic operands is empty, it is returned directly.
2067 if (!dynamicPosition.size())
2068 return {};
2069
2070 // `index` is used to iterate over the `dynamicPosition`.
2071 unsigned index = 0;
2072
2073 // `opChange` is a flag. If it is true, it means to update `op` in place.
2074 bool opChange = false;
2075 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2076 if (ShapedType::isStatic(staticPosition[i]))
2077 continue;
2078 Attribute positionAttr = dynamicPositionAttr[index];
2079 Value position = dynamicPosition[index++];
2080 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2081 int64_t value = attr.getInt();
2082 // Do not fold if the value is out of bounds (-1 signifies a poison
2083 // value rather than OOB index).
2084 if (value >= -1 && value < vectorShape[i]) {
2085 staticPosition[i] = attr.getInt();
2086 opChange = true;
2087 continue;
2088 }
2089 }
2090 operands.push_back(position);
2091 }
2092
2093 if (opChange) {
2094 op.setStaticPosition(staticPosition);
2095 op.getOperation()->setOperands(operands);
2096 // Return the original result to indicate an in-place folding happened.
2097 return op.getResult();
2098 }
2099 return {};
2100}
2101
2102/// Fold an insert or extract operation into an poison value when a poison index
2103/// is found at any dimension of the static position.
2105 ArrayRef<int64_t> staticPos,
2106 int64_t poisonVal) {
2107 if (!is_contained(staticPos, poisonVal))
2108 return {};
2109
2110 return ub::PoisonAttr::get(context);
2111}
2112
2113/// Fold a vector extract from is a poison source.
2115 if (matchPattern(srcAttr, ub::m_Poison()))
2116 return srcAttr;
2117
2118 return {};
2119}
2120
2121/// Fold a vector extract extracting from a DenseElementsAttr.
2123 Attribute srcAttr) {
2124 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2125 if (!denseAttr) {
2126 return {};
2127 }
2128
2129 if (denseAttr.isSplat()) {
2130 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2131 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2132 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2133 return newAttr;
2134 }
2135
2136 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2137 if (vecTy.isScalable())
2138 return {};
2139
2140 if (extractOp.hasDynamicPosition()) {
2141 return {};
2142 }
2143
2144 // Materializing subsets of a large constant array can generally lead to
2145 // explosion in IR size because of different combination of subsets that
2146 // can exist. However, vector.extract is a restricted form of subset
2147 // extract where you can only extract non-overlapping (or the same) subset for
2148 // a given rank of the subset. Because of this property, the IR size can only
2149 // increase at most by `rank * size(array)` from a single constant array being
2150 // extracted by multiple extracts.
2151
2152 // Calculate the linearized position of the continuous chunk of elements to
2153 // extract.
2154 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2155 copy(extractOp.getStaticPosition(), completePositions.begin());
2156 int64_t startPos =
2157 linearize(completePositions, computeStrides(vecTy.getShape()));
2158 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2159
2160 TypedAttr newAttr;
2161 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2162 SmallVector<Attribute> elementValues(
2163 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2164 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2165 } else {
2166 newAttr = *denseValuesBegin;
2167 }
2168
2169 return newAttr;
2170}
2171
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
  // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
  // mismatch).
  if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
    return getSource();
  // Extracting from a poison source yields poison.
  if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
    return res;
  // Fold `arith.constant` indices into the `vector.extract` operation.
  // Do not stop here as this fold may enable subsequent folds that require
  // constant indices.
  SmallVector<Value> operands = {getSource()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);

  // A poison index anywhere in the static position folds the whole op to
  // poison.
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  // Constant source: extract the constant (sub-)value directly.
  if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
    return res;
  // Try the source-specific folds in turn; the first that succeeds wins.
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  if (auto res = foldExtractFromBroadcast(*this))
    return res;
  if (auto res = foldExtractFromShuffle(*this))
    return res;
  if (auto res = foldExtractFromShapeCast(*this))
    return res;
  if (auto val = foldExtractFromExtractStrided(*this))
    return val;
  if (auto val = foldExtractStridedOpFromInsertChain(*this))
    return val;
  if (auto val = foldScalarExtractFromFromElements(*this))
    return val;

  // Fall back to the (possibly null) result of the in-place constant-index
  // fold performed above.
  return inplaceFolded;
}
2210
2211namespace {
2212
2213// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2214class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2215public:
2216 using Base::Base;
2217
2218 LogicalResult matchAndRewrite(ExtractOp extractOp,
2219 PatternRewriter &rewriter) const override {
2220
2221 Operation *defOp = extractOp.getSource().getDefiningOp();
2222 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2223 if (!defOp || !isBroadcastLike(defOp) || !outType)
2224 return failure();
2225
2226 Value source = defOp->getOperand(0);
2227 if (isBroadcastableTo(source.getType(), outType) !=
2228 BroadcastableToResult::Success)
2229 return failure();
2230
2231 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2232 return success();
2233 }
2234};
2235
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
//
// The leading mask operands consumed by the extract position are analyzed:
// if the position is statically outside any bound the result is a constant
// all-false mask; if all analyzed dims are provably all-true (or statically
// in bounds) the extract becomes a smaller create_mask over the remaining
// operands; otherwise the pattern fails.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto createMaskOp =
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());

    // Only vector results are handled by this pattern.
    if (!extractedMaskType)
      return failure();

    auto maskOperands = createMaskOp.getOperands();
    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    VectorType maskType = createMaskOp.getVectorType();

    bool containsUnknownDims = false;
    bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;

    // Classify each dim addressed by the extract position; stop early once
    // the result is known to be all-false.
    for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
         dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      Value operand = maskOperands[dimIdx];
      auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
      if (!constantOp) {
        // Bounds of this dim unknown.
        containsUnknownDims = true;
        continue;
      }

      int64_t createMaskBound =
          llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();

      if (pos != ShapedType::kDynamic) {
        // If any position is outside the range from the `create_mask`, then the
        // extracted mask will be all-false.
        allFalse |= pos >= createMaskBound;
      } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
        // This dim is not all-true and since this is a dynamic index we don't
        // know if the extraction is within the true or false region.
        // Note: Zero dims have already handled via getMaskFormat().
        containsUnknownDims = true;
      }
    }

    if (allFalse) {
      // Statically outside the masked region: constant all-false result.
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, DenseElementsAttr::get(extractedMaskType, false));
    } else if (!containsUnknownDims) {
      // Fully analyzable: a smaller create_mask over the trailing operands.
      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
          extractOp, extractedMaskType,
          maskOperands.drop_front(extractOpPos.size()));
    } else {
      return failure();
    }
    return success();
  }
};
2300
// Pattern to rewrite a ExtractOp(ConstantMask) -> ConstantMask.
//
// The result is either a constant all-false value (some position is outside
// the masked region), a smaller constant_mask (vector result, all positions
// in bounds), or a constant true (scalar result, all positions in bounds).
class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto constantMaskOp =
        extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
    if (!constantMaskOp)
      return failure();

    // Null for scalar results; both cases are handled below.
    Type resultType = extractOp.getResult().getType();
    auto extractedMaskType = dyn_cast<VectorType>(resultType);

    ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
    ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();

    VectorType maskType = constantMaskOp.getVectorType();

    // Check if any extracted position is outside the mask bounds.
    for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
      int64_t pos = extractOpPos[dimIdx];
      if (pos == ShapedType::kDynamic) {
        // If the dim is all-true, a dynamic index is fine — any position
        // is within the masked region.
        if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
          continue;
        // Otherwise we don't know if the position is inside or outside of
        // the masked area, so bail out.
        return failure();
      }

      // If the position is statically outside of the masked area, the result
      // will be all-false.
      if (pos >= maskDimSizes[dimIdx]) {
        if (extractedMaskType) {
          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
              extractOp, DenseElementsAttr::get(extractedMaskType, false));
        } else {
          rewriter.replaceOpWithNewOp<arith::ConstantOp>(
              extractOp, rewriter.getIntegerAttr(resultType, false));
        }
        return success();
      }
    }

    // All positions are within the mask bounds.
    if (extractedMaskType) {
      // Vector result: the result is a constant_mask with the remaining
      // dimensions.
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
          extractOp, extractedMaskType,
          maskDimSizes.drop_front(extractOpPos.size()));
    } else {
      // Scalar result: all positions are within the masked region, so the
      // result is true.
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
          extractOp, rewriter.getIntegerAttr(resultType, true));
    }
    return success();
  }
};
2364
2365// Folds extract(shape_cast(..)) into shape_cast when the total element count
2366// does not change.
2367LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2368 PatternRewriter &rewriter) {
2369 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2370 if (!castOp)
2371 return failure();
2372
2373 VectorType sourceType = castOp.getSourceVectorType();
2374 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2375 if (!targetType)
2376 return failure();
2377
2378 if (sourceType.getNumElements() != targetType.getNumElements())
2379 return failure();
2380
2381 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2382 castOp.getSource());
2383 return success();
2384}
2385
2386/// Try to canonicalize the extraction of a subvector from a vector defined by
2387/// vector.from_elements. E.g.:
2388///
2389/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2390/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2391/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2392LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2393 PatternRewriter &rewriter) {
2394 // Dynamic positions are not supported.
2395 if (extractOp.hasDynamicPosition())
2396 return failure();
2397
2398 // Scalar extracts are handled by the folder.
2399 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2400 if (!resultType)
2401 return failure();
2402
2403 // Look for extracts from a from_elements op.
2404 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2405 if (!fromElementsOp)
2406 return failure();
2407 VectorType inputType = fromElementsOp.getType();
2408
2409 // Scalable vectors are not supported.
2410 if (resultType.isScalable() || inputType.isScalable())
2411 return failure();
2412
2413 // Compute the position of first extracted element and flatten/linearize the
2414 // position.
2415 SmallVector<int64_t> firstElementPos =
2416 llvm::to_vector(extractOp.getStaticPosition());
2417 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2418 int flatIndex = 0;
2419 int stride = 1;
2420 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2421 flatIndex += firstElementPos[i] * stride;
2422 stride *= inputType.getDimSize(i);
2423 }
2424
2425 // Replace the op with a smaller from_elements op.
2426 rewriter.replaceOpWithNewOp<FromElementsOp>(
2427 extractOp, resultType,
2428 fromElementsOp.getElements().slice(flatIndex,
2429 resultType.getNumElements()));
2430 return success();
2431}
2432
2433/// Replace `vector.extract` with `vector.shape_cast`.
2434///
2435/// BEFORE:
2436/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2437/// AFTER:
2438/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2439///
2440/// The canonical form of vector operations that reshape vectors is shape_cast.
2441struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2442 using Base::Base;
2443 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2444 PatternRewriter &rewriter) const override {
2445 VectorType sourceType = extractOp.getSourceVectorType();
2446 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2447 if (!outType)
2448 return failure();
2449
2450 if (sourceType.getNumElements() != outType.getNumElements())
2451 return rewriter.notifyMatchFailure(
2452 extractOp, "extract to vector with fewer elements");
2453
2454 // Negative values in `position` means that the extacted value is poison.
2455 // There is a vector.extract folder for this.
2456 if (llvm::any_of(extractOp.getMixedPosition(),
2457 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2458 return rewriter.notifyMatchFailure(extractOp,
2459 "leaving for extract poison folder");
2460
2461 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2462 extractOp.getSource());
2463
2464 return success();
2465 }
2466};
2467
2468} // namespace
2469
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Pattern classes first, then the free-function rewrites.
  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
              ExtractOpFromConstantMask, ExtractToShapeCast>(context);
  results.add(foldExtractFromShapeCastToShapeCast);
  results.add(foldExtractFromFromElements);
}
2477
2479 SmallVectorImpl<int64_t> &results) {
2480 for (auto attr : arrayAttr)
2481 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2482}
2483
2484//===----------------------------------------------------------------------===//
2485// FmaOp
2486//===----------------------------------------------------------------------===//
2487
2488std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2489 return llvm::to_vector<4>(getVectorType().getShape());
2490}
2491
2492//===----------------------------------------------------------------------===//
2493// ToElementsOp
2494//===----------------------------------------------------------------------===//
2495
2496/// Returns true if all the `operands` are defined by `defOp`.
2497/// Otherwise, returns false.
2498static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2499 if (operands.empty())
2500 return false;
2501
2502 return llvm::all_of(operands, [&](Value operand) {
2503 Operation *currentDef = operand.getDefiningOp();
2504 return currentDef == defOp;
2505 });
2506}
2507
/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
/// (%e0, %e1, ...). For example:
///
///   %0 = vector.from_elements %a, %b, %c : vector<3xf32>
///   %1:3 = vector.to_elements %0 : vector<3xf32>
///   user_op %1#0, %1#1, %1#2
///
/// becomes:
///
///   user_op %a, %b, %c
///
static LogicalResult
foldToElementsFromElements(ToElementsOp toElementsOp,
  auto fromElementsOp =
      toElementsOp.getSource().getDefiningOp<FromElementsOp>();
  if (!fromElementsOp)
    return failure();

  // Replace each to_elements result with the matching from_elements operand.
  llvm::append_range(results, fromElementsOp.getElements());
  return success();
}
2530
/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
///
/// Example:
///   %b = vector.broadcast %x : f32 to vector<3xf32>
///   %e:3 = vector.to_elements %b : vector<3xf32>
///   user_op %e#0, %e#1, %e#2
/// becomes:
///   user_op %x, %x, %x
///
/// The vector source case is handled by a canonicalization pattern.
static LogicalResult
foldToElementsOfBroadcast(ToElementsOp toElementsOp,
  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
  if (!bcastOp)
    return failure();
  // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
  if (isa<VectorType>(bcastOp.getSource().getType()))
    return failure();

  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());

  // Every result of to_elements is the single broadcast scalar.
  Value scalar = bcastOp.getSource();
  results.assign(resultVecType.getNumElements(), scalar);
  return success();
}
2557
2558LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2559 SmallVectorImpl<OpFoldResult> &results) {
2560 if (succeeded(foldToElementsFromElements(*this, results)))
2561 return success();
2562
2563 // Y = ToElements(ShapeCast(X)) -> Y = ToElements(X)
2564 if (auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2565 setOperand(shapeCast.getSource());
2566 return success();
2567 }
2568
2569 return foldToElementsOfBroadcast(*this, results);
2570}
2571
2572LogicalResult
2573ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2574 ToElementsOp::Adaptor adaptor,
2575 SmallVectorImpl<Type> &inferredReturnTypes) {
2576 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2577 Type elType = vecType.getElementType();
2578 inferredReturnTypes.append(vecType.getNumElements(), elType);
2579 return success();
2580}
2581
/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
/// vector.
/// - Build `vector.to_elements %v` and remap each destination element to the
///   corresponding source element using broadcast rules (match or 1 →
///   replicate).
///
/// Example:
///   %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
///   %e:6 = vector.to_elements %v : vector<3x2xf32>
/// becomes:
///   %src_elems:2 = vector.to_elements %src : vector<2xf32>
///   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
///   //       %src_elems#1, %src_elems#0, %src_elems#1
struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
                                PatternRewriter &rewriter) const override {
    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
    if (!bcastOp)
      return failure();

    // Only handle broadcasts from a vector source here. The scalar-source
    // case is handled by the op's folder.
    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
    if (!srcType)
      return failure();

    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());

    ArrayRef<int64_t> dstShape = dstType.getShape();
    ArrayRef<int64_t> srcShape = srcType.getShape();

    int64_t dstRank = dstShape.size();
    int64_t srcRank = srcShape.size();

    // Create elements for the broadcast source vector.
    auto srcElems = vector::ToElementsOp::create(
        rewriter, toElementsOp.getLoc(), bcastOp.getSource());

    int64_t dstCount = llvm::product_of(dstShape);

    SmallVector<Value> replacements;
    replacements.reserve(dstCount);

    // For each element of the destination, determine which element of the
    // source should be used. We walk all destination positions using a single
    // counter, decode it into per-dimension indices, then build the matching
    // source position: use the same index where sizes match, and use 0 where
    // the source size is 1 (replication). This mapping is needed so we can
    // replace each result of to_elements with the corresponding element from
    // the broadcast source.
    // Inner-dimension stretch example:
    //   %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
    //   %e:12 = vector.to_elements %v : vector<2x3x2xf32>
    // becomes:
    //   %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
    //   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
    //   //       %src_elems#1, %src_elems#0, %src_elems#1,
    //   //       %src_elems#2, %src_elems#3, %src_elems#2,
    //   //       %src_elems#3, %src_elems#2, %src_elems#3

    // Row-major strides for the destination shape.
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);
    // Row-major strides for the source shape.
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstIdx(dstRank);
    SmallVector<int64_t> srcIdx(srcRank);
    for (int64_t lin = 0; lin < dstCount; ++lin) {
      // Convert linear destination index to per-dimension indices.
      dstIdx = delinearize(lin, dstStrides);
      // Broadcasting aligns trailing dims, hence the `dstRank - srcRank`
      // offset; unit source dims replicate, so they always map to index 0.
      for (int64_t k = 0; k < srcRank; ++k)
        srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
      // Convert per-dimension source indices back to a linear index.
      int64_t srcLin = linearize(srcIdx, srcStrides);
      replacements.push_back(srcElems.getResult(srcLin));
    }

    rewriter.replaceOp(toElementsOp, replacements);
    return success();
  }
};
2663
void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Handles to_elements of a broadcast from a *vector* source; the scalar
  // source case is handled by the op's folder.
  results.add<ToElementsOfBroadcast>(context);
}
2668
2669//===----------------------------------------------------------------------===//
2670// FromElementsOp
2671//===----------------------------------------------------------------------===//
2672
2673/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2674///
2675/// Case #1: Input and output vectors are the same.
2676///
2677/// %0:3 = vector.to_elements %a : vector<3xf32>
2678/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2679/// user_op %1
2680///
2681/// becomes:
2682///
2683/// user_op %a
2684///
2685static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2686 OperandRange fromElemsOperands = fromElementsOp.getElements();
2687 if (fromElemsOperands.empty())
2688 return {};
2689
2690 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2691 if (!toElementsOp)
2692 return {};
2693
2694 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2695 return {};
2696
2697 // Case #1: Input and output vectors are the same. Forward the input vector.
2698 Value toElementsInput = toElementsOp.getSource();
2699 if (fromElementsOp.getType() == toElementsInput.getType() &&
2700 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2701 return toElementsInput;
2702 }
2703
2704 // TODO: Support cases with different input and output shapes and different
2705 // number of elements.
2706
2707 return {};
2708}
2709
2710/// Fold vector.from_elements to a constant when all operands are constants.
2711/// Example:
2712/// %c1 = arith.constant 1 : i32
2713/// %c2 = arith.constant 2 : i32
2714/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2715/// =>
2716/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2717///
2718static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2719 ArrayRef<Attribute> elements) {
2720 // Check for null or poison attributes before any processing.
2721 if (llvm::any_of(elements, [](Attribute attr) {
2722 return !attr || matchPattern(attr, ub::m_Poison());
2723 }))
2724 return {};
2725
2726 // DenseElementsAttr only supports int/index/float/complex types.
2727 auto destVecType = fromElementsOp.getDest().getType();
2728 auto destEltType = destVecType.getElementType();
2729 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2730 return {};
2731
2732 // Constant attributes might have a different type than the return type.
2733 // Convert them before creating the dense elements attribute.
2734 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2735 return convertNumericAttr(attr, destEltType);
2736 });
2737
2738 return DenseElementsAttr::get(destVecType, convertedElements);
2739}
2740
2741OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2742 if (auto res = foldFromElementsToElements(*this))
2743 return res;
2744 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2745 return res;
2746
2747 return {};
2748}
2749
2750/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2751/// same. Example:
2752/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2753/// =>
2754/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2755static LogicalResult
2756rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2757 PatternRewriter &rewriter) {
2758 if (!llvm::all_equal(fromElementsOp.getElements()))
2759 return failure();
2760 rewriter.replaceOpWithNewOp<BroadcastOp>(
2761 fromElementsOp, fromElementsOp.getType(),
2762 fromElementsOp.getElements().front());
2763 return success();
2764}
2765
/// Rewrite from_elements on multiple scalar extracts as a shape_cast
/// on a single extract. Example:
///   %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
///   %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
///   %2 = vector.from_elements %0, %1 : vector<2xi8>
///
/// becomes
///   %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
///   %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
///
/// The requirements for this to be valid are
///
///   i) The elements are extracted from the same vector (%source).
///
///   ii) The elements form a suffix of %source. Specifically, the number
///   of elements is the same as the product of the last N dimension sizes
///   of %source, for some N.
///
///   iii) The elements are extracted contiguously in ascending order.

class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {

  using Base::Base;

  LogicalResult matchAndRewrite(FromElementsOp fromElements,
                                PatternRewriter &rewriter) const override {

    // Handled by `rewriteFromElementsAsBroadcast`.
    if (fromElements.getType().getNumElements() == 1)
      return failure();

    // The common source that all elements are extracted from, if one exists.
    // The position of the combined extract operation, if one is created.
    ArrayRef<int64_t> combinedPosition;
    // The expected index of extraction of the current element in the loop, if
    // elements are extracted contiguously in ascending order.
    SmallVector<int64_t> expectedPosition;

    for (auto [insertIndex, element] :
         llvm::enumerate(fromElements.getElements())) {

      // Check that the element is from a vector.extract operation.
      auto extractOp = element.getDefiningOp<vector::ExtractOp>();
      if (!extractOp) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element not from vector.extract");
      }

      // Check condition (i) by checking that all elements have the same source
      // as the first element.
      if (insertIndex == 0) {
        source = extractOp.getSource();
      } else if (extractOp.getSource() != source) {
        return rewriter.notifyMatchFailure(fromElements,
                                           "element from different vector");
      }

      ArrayRef<int64_t> position = extractOp.getStaticPosition();
      int64_t rank = position.size();
      assert(rank == source.getType().getRank() &&
             "scalar extract must have full rank position");

      // Check condition (ii) by checking that the position that the first
      // element is extracted from has sufficient trailing 0s. For example, in
      //
      //   %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
      //   [...]
      //   %elms = vector.from_elements %elm0, [...] : vector<12xi8>
      //
      // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 =
      // 12 elements, which is the number of elements of %elms, so this is
      // valid.
      if (insertIndex == 0) {
        const int64_t numElms = fromElements.getType().getNumElements();
        int64_t numSuffixElms = 1;
        int64_t index = rank;
        // Walk trailing zero positions inward, accumulating the suffix size.
        while (index > 0 && position[index - 1] == 0 &&
               numSuffixElms < numElms) {
          numSuffixElms *= source.getType().getDimSize(index - 1);
          --index;
        }
        if (numSuffixElms != numElms) {
          return rewriter.notifyMatchFailure(
              fromElements, "elements do not form a suffix of source");
        }
        expectedPosition = llvm::to_vector(position);
        combinedPosition = position.drop_back(rank - index);
      }

      // Check condition (iii).
      else if (expectedPosition != position) {
        return rewriter.notifyMatchFailure(
            fromElements, "elements not in ascending order (static order)");
      }
      increment(expectedPosition, source.getType().getShape());
    }

    auto extracted = rewriter.createOrFold<vector::ExtractOp>(
        fromElements.getLoc(), source, combinedPosition);

    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
        fromElements, fromElements.getType(), extracted);

    return success();
  }

  /// Increments n-D `indices` by 1 starting from the innermost dimension.
  static void increment(MutableArrayRef<int64_t> indices,
    for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
      indices[dim] += 1;
      // No carry needed once a dimension stays within bounds.
      if (indices[dim] < shape[dim])
        break;
      indices[dim] = 0;
    }
  }
};
2883
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Rewrites a from_elements of contiguous scalar extracts into a single
  // vector.extract followed by a shape_cast.
  results.add<FromElementsToShapeCast>(context);
}
2889
2890//===----------------------------------------------------------------------===//
2891// BroadcastOp
2892//===----------------------------------------------------------------------===//
2893
2894void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2895 SetIntRangeFn setResultRanges) {
2896 setResultRanges(getResult(), argRanges.front());
2897}
2898
2899std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2900 return llvm::to_vector<4>(getResultVectorType().getShape());
2901}
2902
/// Return the dimensions of the result vector that were formerly ones in the
/// source tensor and thus correspond to "dim-1" broadcasting.
static llvm::SetVector<int64_t>
                            ArrayRef<int64_t> dstShape) {
  // Leading dst dims beyond the source rank are plain rank expansion, not
  // "dim-1" broadcasts; only trailing, zipped dims are inspected.
  int64_t rankDiff = dstShape.size() - srcShape.size();
  int64_t dstDim = rankDiff;
  for (auto [s1, s2] :
       llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
    if (s1 != s2) {
      // Any size mismatch must stem from a unit source dim being stretched.
      assert(s1 == 1 && "expected \"dim-1\" broadcasting");
      res.insert(dstDim);
    }
    ++dstDim;
  }
  return res;
}
2921
2922llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2923 // Scalar broadcast is without any unit dim broadcast.
2924 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2925 if (!srcVectorType)
2926 return {};
2927 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2928 getResultVectorType().getShape());
2929}
2930
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
/// `broadcastedDims` dimensions in the dstShape are broadcasted.
/// This requires (and asserts) that the broadcast is free of "dim-1"
/// broadcasting.
/// Since vector.broadcast only allows expanding leading dimensions, an extra
/// vector.transpose may be inserted to make the broadcast possible.
/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
/// the helper will assert. This means:
///   1. `dstShape` must not be empty.
///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
///      must match the `value` shape.
Value BroadcastOp::createOrFoldBroadcastOp(
    OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
    const llvm::SetVector<int64_t> &broadcastedDims) {
  assert(!dstShape.empty() && "unexpected empty dst shape");

  // Step 1. Well-formedness check: dstShape with the broadcasted dims removed
  // must equal the source shape (verified against `value` below).
  SmallVector<int64_t> checkShape;
  for (int i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i))
      continue;
    checkShape.push_back(dstShape[i]);
  }
  assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
         "ill-formed broadcastedDims contains values not confined to "
         "destVectorShape");

  Location loc = value.getLoc();
  Type elementType = getElementTypeOrSelf(value.getType());
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
  VectorType dstVectorType = VectorType::get(dstShape, elementType);

  // Step 2. If scalar -> dstShape broadcast, just do it.
  if (!srcVectorType) {
    assert(checkShape.empty() &&
           "ill-formed createOrFoldBroadcastOp arguments");
    return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
  }

  assert(srcVectorType.getShape().equals(checkShape) &&
         "ill-formed createOrFoldBroadcastOp arguments");

  // Step 3. Since vector.broadcast only allows creating leading dims,
  // vector -> dstShape broadcast may require a transpose.
  // Traverse the dims in order and construct:
  //   1. The leading entries of the broadcastShape that is guaranteed to be
  //      achievable by a simple broadcast.
  //   2. The induced permutation for the subsequent vector.transpose that will
  //      bring us from `broadcastShape` back to the desired `dstShape`.
  // If the induced permutation is not the identity, create a vector.transpose.
  SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
  broadcastShape.reserve(dstShape.size());
  // Consider the example:
  //   srcShape     = 2x4
  //   dstShape     = 1x2x3x4x5
  //   broadcastedDims = [0, 2, 4]
  //
  // We want to build:
  //   broadcastShape  = 1x3x5x2x4
  //   permutation     = [0, 2, 4, 1, 3]
  //                      ---V---  -----V-----
  //             leading broadcast part      src shape part
  //
  // Note that the trailing dims of broadcastShape are exactly the srcShape
  // by construction.
  // nextSrcShapeDim is used to keep track of where in the permutation the
  // "src shape part" occurs.
  int64_t nextSrcShapeDim = broadcastedDims.size();
  for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
    if (broadcastedDims.contains(i)) {
      // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
      // bring it to the head of the broadcastShape.
      // It will need to be permuted back from `broadcastShape.size() - 1` into
      // position `i`.
      broadcastShape.push_back(dstShape[i]);
      permutation[i] = broadcastShape.size() - 1;
    } else {
      // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
      // shape and needs to be permuted into position `i`.
      // Don't touch `broadcastShape` here, the whole srcShape will be
      // appended after.
      permutation[i] = nextSrcShapeDim++;
    }
  }
  // 3.c. Append the srcShape.
  llvm::append_range(broadcastShape, srcVectorType.getShape());

  // Ensure there are no "dim-1" broadcasts.
  assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
             .empty() &&
         "unexpected \"dim-1\" broadcast");

  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
  assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
             vector::BroadcastableToResult::Success &&
         "must be broadcastable");
  Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
  // Step 4. If we find any dimension that indeed needs to be permuted,
  // immediately return a new vector.transpose.
  for (int64_t i = 0, e = permutation.size(); i < e; ++i)
    if (permutation[i] != i)
      return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
  // Otherwise return res.
  return res;
}
3037
    Type srcType, VectorType dstVectorType,
    std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Broadcast scalar to vector of the same element type.
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
      srcType == getElementTypeOrSelf(dstVectorType))
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  if (srcRank > dstRank)
  // Source has an exact match or singleton value for all trailing dimensions
  // (all leading dimensions are simply duplicated).
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have mismatching dims (in the sense of vector.broadcast semantics) been
    // encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims: a source dim must be 1 or equal to the
    // corresponding (trailing-aligned) destination dim.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine, everything else should be rejected when mixing
        // fixed-width and scalable dims
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      // Report the offending src/dst dim pair to the caller, if requested.
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
    }
  }

}
3092
LogicalResult BroadcastOp::verify() {
  // Check broadcastability of the source type to the result vector type and
  // emit a dedicated diagnostic per failure mode.
  std::pair<VectorDim, VectorDim> mismatchingDims;
      getSourceType(), getResultVectorType(), &mismatchingDims);
    return success();
    return emitOpError("source rank higher than destination rank");
    // Scalable dims are printed in brackets, e.g. "[4] vs. 4".
    return emitOpError("dimension mismatch (")
           << (mismatchingDims.first.isScalable ? "[" : "")
           << mismatchingDims.first.dim
           << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
           << (mismatchingDims.second.isScalable ? "[" : "")
           << mismatchingDims.second.dim
           << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
    return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
3114
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type and shape_cast only adds or removes ones in the
// leading dimensions.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
  if (!srcShapeCast)
    return failure();

  VectorType srcType = srcShapeCast.getSourceVectorType();
  VectorType destType = broadcastOp.getResultVectorType();
  // Check type compatibility.
  if (vector::isBroadcastableTo(srcType, destType) !=
    return failure();

  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> shapecastShape =
      srcShapeCast.getResultVectorType().getShape();
  // Trailing dimensions should be the same if shape_cast only alters the
  // leading dimensions.
  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
  if (!llvm::equal(srcShape.take_back(numTrailingDims),
                   shapecastShape.take_back(numTrailingDims)))
    return failure();

  // All dims outside the common trailing part must be unit dims on both
  // sides: shape_cast only added/removed leading 1s.
  assert(all_of(srcShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         all_of(shapecastShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         "ill-formed shape_cast");

  // Fold in place: broadcast directly from the shape_cast's source.
  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
  return success();
}
3149
3150OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3151 if (getSourceType() == getResultVectorType())
3152 return getSource();
3153 if (succeeded(foldBroadcastOfShapeCast(*this)))
3154 return getResult();
3155
3156 if (!adaptor.getSource())
3157 return {};
3158 auto vectorType = getResultVectorType();
3159 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3160 if (vectorType.getElementType() != attr.getType())
3161 return {};
3162 return DenseElementsAttr::get(vectorType, attr);
3163 }
3164 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3165 if (vectorType.getElementType() != attr.getType())
3166 return {};
3167 return DenseElementsAttr::get(vectorType, attr);
3168 }
3169 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3170 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3171 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
3172 return ub::PoisonAttr::get(getContext());
3173 return {};
3174}
3175
3176namespace {
3177
3178// Fold broadcast1(broadcast2(x)) into broadcast1(x).
3179struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3180 using Base::Base;
3181
3182 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3183 PatternRewriter &rewriter) const override {
3184 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3185 if (!srcBroadcast)
3186 return failure();
3187 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3188 broadcastOp.getResultVectorType(),
3189 srcBroadcast.getSource());
3190 return success();
3191 }
3192};
3193
3194/// Replace `vector.broadcast` with `vector.shape_cast`.
3195///
3196/// BEFORE:
3197/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3198/// AFTER:
3199/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3200///
3201/// The canonical form of vector operations that reshape vectors is shape_cast.
3202struct BroadcastToShapeCast final
3203 : public OpRewritePattern<vector::BroadcastOp> {
3204 using Base::Base;
3205 LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
3206 PatternRewriter &rewriter) const override {
3207
3208 auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
3209 if (!sourceType) {
3210 return rewriter.notifyMatchFailure(
3211 broadcast, "source is a scalar, shape_cast doesn't support scalar");
3212 }
3213
3214 VectorType outType = broadcast.getType();
3215 if (sourceType.getNumElements() != outType.getNumElements()) {
3216 return rewriter.notifyMatchFailure(
3217 broadcast, "broadcast to a greater number of elements");
3218 }
3219
3220 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
3221 broadcast.getSource());
3222 return success();
3223 }
3224};
3225} // namespace
3226
3227void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3228 MLIRContext *context) {
3229 results.add<BroadcastFolder, BroadcastToShapeCast>(context);
3230}
3231
3232//===----------------------------------------------------------------------===//
3233// ShuffleOp
3234//===----------------------------------------------------------------------===//
3235
3236LogicalResult ShuffleOp::verify() {
3237 VectorType resultType = getResultVectorType();
3238 VectorType v1Type = getV1VectorType();
3239 VectorType v2Type = getV2VectorType();
3240 // Verify ranks.
3241 int64_t resRank = resultType.getRank();
3242 int64_t v1Rank = v1Type.getRank();
3243 int64_t v2Rank = v2Type.getRank();
3244 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3245 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3246 if (!wellFormed0DCase && !wellFormedNDCase)
3247 return emitOpError("rank mismatch");
3248
3249 // Verify all but leading dimension sizes.
3250 for (int64_t r = 1; r < v1Rank; ++r) {
3251 int64_t resDim = resultType.getDimSize(r);
3252 int64_t v1Dim = v1Type.getDimSize(r);
3253 int64_t v2Dim = v2Type.getDimSize(r);
3254 if (resDim != v1Dim || v1Dim != v2Dim)
3255 return emitOpError("dimension mismatch");
3256 }
3257 // Verify mask length.
3258 ArrayRef<int64_t> mask = getMask();
3259 int64_t maskLength = mask.size();
3260 if (maskLength <= 0)
3261 return emitOpError("invalid mask length");
3262 if (maskLength != resultType.getDimSize(0))
3263 return emitOpError("mask length mismatch");
3264 // Verify all indices.
3265 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3266 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3267 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3268 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3269 return emitOpError("mask index #") << (idx + 1) << " out of range";
3270 }
3271 return success();
3272}
3273
3274LogicalResult
3275ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3276 ShuffleOp::Adaptor adaptor,
3277 SmallVectorImpl<Type> &inferredReturnTypes) {
3278 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3279 if (!v1Type) {
3280 return emitOptionalError(loc, "expected vector type");
3281 }
3282 auto v1Rank = v1Type.getRank();
3283 // Construct resulting type: leading dimension matches mask
3284 // length, all trailing dimensions match the operands.
3285 SmallVector<int64_t, 4> shape;
3286 shape.reserve(v1Rank);
3287 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3288 // In the 0-D case there is no trailing shape to append.
3289 if (v1Rank > 0)
3290 llvm::append_range(shape, v1Type.getShape().drop_front());
3291 inferredReturnTypes.push_back(
3292 VectorType::get(shape, v1Type.getElementType()));
3293 return success();
3294}
3295
3296template <typename T>
3297static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3298 T expected = begin;
3299 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3300 return value == expected++;
3301 });
3302}
3303
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
  auto v1Type = getV1VectorType();
  auto v2Type = getV2VectorType();

  assert(!v1Type.isScalable() && !v2Type.isScalable() &&
         "Vector shuffle does not support scalable vectors");

  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
  // but must be a canonicalization into a vector.broadcast.
  if (v1Type.getRank() == 0)
    return {};

  // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
  auto mask = getMask();
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return getV1();
  // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return getV2();

  // Constant folding below requires both operands to be known attributes.
  Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
  if (!v1Attr || !v2Attr)
    return {};

  // Fold shuffle poison, poison -> poison.
  bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
  bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
  if (isV1Poison && isV2Poison)
    return ub::PoisonAttr::get(getContext());

  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (v1Type.getRank() != 1)
    return {};

  // Poison input attributes need special handling as they are not
  // DenseElementsAttr. If an index is poison, we select the first element of
  // the first non-poison input.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    if (!v2DenseAttr)
      return {};
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    if (!v1DenseAttr)
      return {};
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    // Prefer v1's first element as the poison placeholder when available.
    poisonElement = v1Elements[0];
  }

  // Gather one attribute per mask entry: indices < v1Size select from V1,
  // the rest select from V2 (offset by v1Size).
  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    // TODO: Return a partial poison vector when supported by the UB dialect.
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(getResultVectorType(), results);
}
3378
3379namespace {
3380
3381// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3382// to a broadcast.
3383struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3384 using Base::Base;
3385
3386 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3387 PatternRewriter &rewriter) const override {
3388 VectorType v1VectorType = shuffleOp.getV1VectorType();
3389 ArrayRef<int64_t> mask = shuffleOp.getMask();
3390 if (v1VectorType.getRank() > 0)
3391 return failure();
3392 if (mask.size() != 1)
3393 return failure();
3394 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3395 if (mask[0] == 0)
3396 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3397 shuffleOp.getV1());
3398 else
3399 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3400 shuffleOp.getV2());
3401 return success();
3402 }
3403};
3404
3405/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3406/// vector.broadcast with a scalar operand, return the scalar value that is
3407/// splatted. Otherwise return null.
3408///
3409/// Example:
3410///
3411/// scalar_source --> vector.broadcast --> value - return scalar_source
3412static Value getScalarSplatSource(Value value) {
3413 // Block argument:
3414 Operation *defOp = value.getDefiningOp();
3415 if (!defOp)
3416 return {};
3417
3418 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3419
3420 // Not broadcast (and not splat):
3421 if (!broadcast)
3422 return {};
3423
3424 // Broadcast of a vector:
3425 if (isa<VectorType>(broadcast.getSourceType()))
3426 return {};
3427
3428 // Broadcast of a scalar:
3429 return broadcast.getSource();
3430}
3431
3432/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
3433class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3434public:
3435 using Base::Base;
3436
3437 LogicalResult matchAndRewrite(ShuffleOp op,
3438 PatternRewriter &rewriter) const override {
3439 Value splat = getScalarSplatSource(op.getV1());
3440 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3441 return failure();
3442
3443 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3444 return success();
3445 }
3446};
3447
3448/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3449/// vector.interleave.
3450class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3451public:
3452 using Base::Base;
3453
3454 LogicalResult matchAndRewrite(ShuffleOp op,
3455 PatternRewriter &rewriter) const override {
3456 VectorType resultType = op.getResultVectorType();
3457 if (resultType.isScalable())
3458 return rewriter.notifyMatchFailure(
3459 op, "ShuffleOp can't represent a scalable interleave");
3460
3461 if (resultType.getRank() != 1)
3462 return rewriter.notifyMatchFailure(
3463 op, "ShuffleOp can't represent an n-D interleave");
3464
3465 VectorType sourceType = op.getV1VectorType();
3466 if (sourceType != op.getV2VectorType() ||
3467 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3468 return rewriter.notifyMatchFailure(
3469 op, "ShuffleOp types don't match an interleave");
3470 }
3471
3472 ArrayRef<int64_t> shuffleMask = op.getMask();
3473 int64_t resultVectorSize = resultType.getNumElements();
3474 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3475 int64_t maskValueA = shuffleMask[i * 2];
3476 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3477 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3478 return rewriter.notifyMatchFailure(op,
3479 "ShuffleOp mask not interleaving");
3480 }
3481
3482 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3483 return success();
3484 }
3485};
3486
3487} // namespace
3488
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Fold splat shuffles to broadcasts, recognize interleaves, and lower 0-D
  // shuffles to broadcasts.
  results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
      context);
}
3494
3495//===----------------------------------------------------------------------===//
3496// InsertOp
3497//===----------------------------------------------------------------------===//
3498
/// Integer range inference: an insert result contains values from either the
/// inserted value (argRanges[0]) or the destination vector (argRanges[1]), so
/// the result range is the union of the two.
void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                         SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
}
3503
/// Build an InsertOp inserting `source` at position [0, 0, ..., 0] (one zero
/// per destination dimension) of `dest`.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest) {
  auto vectorTy = cast<VectorType>(dest.getType());
  build(builder, result, source, dest,
        SmallVector<int64_t>(vectorTy.getRank(), 0));
}
3510
/// Build an InsertOp with a single static position.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, int64_t position) {
  build(builder, result, source, dest, ArrayRef<int64_t>{position});
}
3515
/// Build an InsertOp with a single position that may be static or dynamic.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, OpFoldResult position) {
  build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
}
3520
3521void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3522 Value source, Value dest,
3523 ArrayRef<int64_t> position) {
3524 SmallVector<OpFoldResult> posVals;
3525 posVals.reserve(position.size());
3526 llvm::transform(position, std::back_inserter(posVals),
3527 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3528 build(builder, result, source, dest, posVals);
3529}
3530
/// Build an InsertOp from mixed positions by splitting them into a list of
/// dynamic index operands and a static i64 position attribute.
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest,
                             ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dest, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}
3540
LogicalResult InsertOp::verify() {
  // 0-D vector sources are rejected: a scalar must be inserted as a scalar.
  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
    if (srcTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the source operand");

  SmallVector<OpFoldResult> position = getMixedPosition();
  auto destVectorType = getDestVectorType();
  if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than dest vector rank");
  // For a vector source, position rank + source rank must equal dest rank.
  auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  // For a scalar source, the position must address a single element.
  if (!srcVectorType &&
      (position.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  // Each static position must be in-bounds for its dimension; the dedicated
  // poison index is also accepted. Dynamic positions are not checked here.
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
                                        destVectorType.getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding "
                  "dest vector dimension";
      }
    }
  }
  return success();
}
3577
3578// Calculate the linearized position of the continuous chunk of elements to
3579// insert, based on the shape of the value to insert and the positions to insert
3580// at.
3581static int64_t calculateInsertPosition(VectorType destTy,
3582 ArrayRef<int64_t> positions) {
3583 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3584 assert(positions.size() <= completePositions.size() &&
3585 "positions size must be less than or equal to destTy rank");
3586 copy(positions, completePositions.begin());
3587 return linearize(completePositions, computeStrides(destTy.getShape()));
3588}
3589
3590namespace {
3591
3592// If insertOp is only inserting unit dimensions it can be transformed to a
3593// broadcast.
3594class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3595public:
3596 using Base::Base;
3597
3598 LogicalResult matchAndRewrite(InsertOp insertOp,
3599 PatternRewriter &rewriter) const override {
3600 auto srcVecType =
3601 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3602 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3603 srcVecType.getNumElements())
3604 return failure();
3605 rewriter.replaceOpWithNewOp<BroadcastOp>(
3606 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3607 return success();
3608 }
3609};
3610
3611/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
3612class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3613public:
3614 using Base::Base;
3615
3616 LogicalResult matchAndRewrite(InsertOp op,
3617 PatternRewriter &rewriter) const override {
3618
3619 Value splat = getScalarSplatSource(op.getValueToStore());
3620 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3621 return failure();
3622
3623 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3624 return success();
3625 }
3626};
3627
/// Pattern to optimize a chain of insertions.
///
/// This pattern identifies chains of vector.insert operations that:
/// 1. Only insert values at static positions.
/// 2. Completely initialize all elements in the resulting vector.
/// 3. All intermediate insert operations have only one use.
///
/// When these conditions are met, the entire chain can be replaced with a
/// single vector.from_elements operation.
///
/// To keep this pattern simple, and avoid spending too much time on matching
/// fragmented insert chains, this pattern only considers the last insert op in
/// the chain.
///
/// Example transformation:
///   %poison = ub.poison : vector<2xi32>
///   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
///   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
/// ->
///   %result = vector.from_elements %c1, %c2 : vector<2xi32>
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {

    VectorType destTy = op.getDestVectorType();
    if (destTy.isScalable())
      return failure();
    // Ensure this is the trailing vector.insert op in a chain of inserts.
    for (Operation *user : op.getResult().getUsers())
      if (auto insertOp = dyn_cast<InsertOp>(user))
        if (insertOp.getDest() == op.getResult())
          return failure();

    // Walk the `dest` def-use chain backwards, collecting the chain from the
    // last insert (this op) towards the first.
    InsertOp currentOp = op;
    SmallVector<InsertOp> chainInsertOps;
    while (currentOp) {
      // Check cond 1: Dynamic position is not supported.
      if (currentOp.hasDynamicPosition())
        return failure();

      chainInsertOps.push_back(currentOp);
      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
      // Check cond 3: Intermediate inserts have only one use to avoid an
      // explosion of vectors.
      if (currentOp && !currentOp->hasOneUse())
        return failure();
    }

    int64_t vectorSize = destTy.getNumElements();
    int64_t initializedCount = 0;
    SmallVector<bool> initializedDestIdxs(vectorSize, false);
    SmallVector<int64_t> pendingInsertPos;
    SmallVector<int64_t> pendingInsertSize;
    SmallVector<Value> pendingInsertValues;

    // Iterating from the last insert backwards, so already-marked elements
    // belong to a later (winning) insert.
    for (auto insertOp : chainInsertOps) {
      // This pattern can do nothing with poison index.
      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
        return failure();

      // Calculate the linearized position for inserting elements.
      int64_t insertBeginPosition =
          calculateInsertPosition(destTy, insertOp.getStaticPosition());

      // The valueToStore operand may be a vector or a scalar. Need to handle
      // both cases.
      int64_t insertSize = 1;
      if (auto srcVectorType =
              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
        insertSize = srcVectorType.getNumElements();

      assert(insertBeginPosition + insertSize <= vectorSize &&
             "insert would overflow the vector");

      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
                                           insertBeginPosition + insertSize)) {
        if (initializedDestIdxs[index])
          continue;
        initializedDestIdxs[index] = true;
        ++initializedCount;
      }

      // Defer the creation of ops before we can make sure the pattern can
      // succeed.
      pendingInsertPos.push_back(insertBeginPosition);
      pendingInsertSize.push_back(insertSize);
      pendingInsertValues.push_back(insertOp.getValueToStore());

      if (initializedCount == vectorSize)
        break;
    }

    // Check cond 2: all positions must be initialized.
    if (initializedCount != vectorSize)
      return failure();

    // Replay the inserts in program order (reverse of collection order) so
    // that later inserts overwrite earlier ones.
    SmallVector<Value> elements(vectorSize);
    for (auto [insertBeginPosition, insertSize, valueToStore] :
         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
                                 pendingInsertValues))) {
      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());

      if (!srcVectorType) {
        elements[insertBeginPosition] = valueToStore;
        continue;
      }

      SmallVector<Type> elementToInsertTypes(insertSize,
                                             srcVectorType.getElementType());
      // Get all elements from the vector in row-major order.
      auto elementsToInsert = vector::ToElementsOp::create(
          rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
        elements[insertBeginPosition + linearIdx] =
            elementsToInsert.getResult(linearIdx);
      }
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
    return success();
  }
};
3752
3753} // namespace
3754
3755static Attribute
3757 Attribute dstAttr,
3758 int64_t maxVectorSizeFoldThreshold) {
3759 if (insertOp.hasDynamicPosition())
3760 return {};
3761
3762 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3763 if (!denseDst)
3764 return {};
3765
3766 if (!srcAttr) {
3767 return {};
3768 }
3769
3770 VectorType destTy = insertOp.getDestVectorType();
3771 if (destTy.isScalable())
3772 return {};
3773
3774 // Make sure we do not create too many large constants.
3775 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3776 !insertOp->hasOneUse())
3777 return {};
3778
3779 // Calculate the linearized position for inserting elements.
3780 int64_t insertBeginPosition =
3781 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3782 SmallVector<Attribute> insertedValues;
3783 Type destEltType = destTy.getElementType();
3784
3785 /// Converts attribute to the expected type if there's
3786 /// a mismatch.
3787 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3788 for (auto value : denseSource.getValues<Attribute>())
3789 insertedValues.push_back(convertNumericAttr(value, destEltType));
3790 } else {
3791 insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3792 }
3793
3794 auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3795 copy(insertedValues, allValues.begin() + insertBeginPosition);
3796 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3797
3798 return newAttr;
3799}
3800
3801/// Folder to replace the `dest` operand of the insert op with the root dest of
3802/// the insert op use chain.
3803static Value foldInsertUseChain(InsertOp insertOp) {
3804 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3805 if (!destInsert)
3806 return {};
3807
3808 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3809 return {};
3810
3811 insertOp.setOperand(1, destInsert.getDest());
3812 return insertOp.getResult();
3813}
3814
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // BroadcastFolder is defined elsewhere in this file.
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertChainFullyInitialized>(context);
}
3820
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
  // (type mismatch).
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  // Fold `arith.constant` indices into the `vector.insert` operation.
  // Do not stop here as this fold may enable subsequent folds that require
  // constant indices.
  SmallVector<Value> operands = {getValueToStore(), getDest()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);

  // Try the remaining folders in order; the first one that applies wins.
  if (auto res = foldInsertUseChain(*this))
    return res;
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }

  // Fall back to the (possibly null) result of the in-place constant-index
  // fold above.
  return inplaceFolded;
}
3849
3850//===----------------------------------------------------------------------===//
3851// InsertStridedSliceOp
3852//===----------------------------------------------------------------------===//
3853
3854void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3855 Value source, Value dest,
3856 ArrayRef<int64_t> offsets,
3857 ArrayRef<int64_t> strides) {
3858 result.addOperands({source, dest});
3859 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3860 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3861 result.addTypes(dest.getType());
3862 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3863 offsetsAttr);
3864 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3865 stridesAttr);
3866}
3867
3868// TODO: Should be moved to Tablegen ConfinedAttr attributes.
3869template <typename OpType>
3870static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3871 ArrayAttr arrayAttr,
3873 StringRef attrName) {
3874 if (arrayAttr.size() > shape.size())
3875 return op.emitOpError("expected ")
3876 << attrName << " attribute of rank no greater than vector rank";
3877 return success();
3878}
3879
3880// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
3881// interval. If `halfOpen` is true then the admissible interval is [min, max).
3882// Otherwise, the admissible interval is [min, max].
3883template <typename OpType>
3884static LogicalResult
3886 int64_t max, StringRef attrName,
3887 bool halfOpen = true) {
3888 for (auto attr : arrayAttr) {
3889 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3890 auto upper = max;
3891 if (!halfOpen)
3892 upper += 1;
3893 if (val < min || val >= upper)
3894 return op.emitOpError("expected ") << attrName << " to be confined to ["
3895 << min << ", " << upper << ")";
3896 }
3897 return success();
3898}
3899
3900// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
3901// interval. If `halfOpen` is true then the admissible interval is [min, max).
3902// Otherwise, the admissible interval is [min, max].
3903template <typename OpType>
3904static LogicalResult
3906 ArrayRef<int64_t> shape, StringRef attrName,
3907 bool halfOpen = true, int64_t min = 0) {
3908 for (auto [index, attrDimPair] :
3909 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3910 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3911 int64_t max = std::get<1>(attrDimPair);
3912 if (!halfOpen)
3913 max += 1;
3914 if (val < min || val >= max)
3915 return op.emitOpError("expected ")
3916 << attrName << " dimension " << index << " to be confined to ["
3917 << min << ", " << max << ")";
3918 }
3919 return success();
3920}
3921
3922// Returns true if, for all indices i = 0..shape.size()-1, val is in the
3923// [min, max} interval:
3924// val = `arrayAttr1[i]` + `arrayAttr2[i]`,
3925// If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
3926// the admissible interval is [min, max].
3927template <typename OpType>
3929 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3930 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3931 bool halfOpen = true, int64_t min = 1) {
3932 assert(arrayAttr1.size() <= shape.size());
3933 assert(arrayAttr2.size() <= shape.size());
3934 for (auto [index, it] :
3935 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3936 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3937 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3938 int64_t max = std::get<2>(it);
3939 if (!halfOpen)
3940 max += 1;
3941 if (val1 + val2 < 0 || val1 + val2 >= max)
3942 return op.emitOpError("expected sum(")
3943 << attrName1 << ", " << attrName2 << ") dimension " << index
3944 << " to be confined to [" << min << ", " << max << ")";
3945 }
3946 return success();
3947}
3948
3950 MLIRContext *context) {
3951 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3952 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3953 });
3954 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3955}
3956
3957LogicalResult InsertStridedSliceOp::verify() {
3958 auto sourceVectorType = getSourceVectorType();
3959 auto destVectorType = getDestVectorType();
3960 auto offsets = getOffsetsAttr();
3961 auto strides = getStridesAttr();
3962 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3963 return emitOpError(
3964 "expected offsets of same size as destination vector rank");
3965 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3966 return emitOpError("expected strides of same size as source vector rank");
3967 if (sourceVectorType.getRank() > destVectorType.getRank())
3968 return emitOpError(
3969 "expected source rank to be no greater than destination rank");
3970
3971 auto sourceShape = sourceVectorType.getShape();
3972 auto destShape = destVectorType.getShape();
3973 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3974 destShape.size() - sourceShape.size(), 0);
3975 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3976 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3977 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3978 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3979 offName)) ||
3980 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3981 /*max=*/1, stridesName,
3982 /*halfOpen=*/false)) ||
3984 *this, offsets,
3985 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3986 offName, "source vector shape",
3987 /*halfOpen=*/false, /*min=*/1)))
3988 return failure();
3989
3990 unsigned rankDiff = destShape.size() - sourceShape.size();
3991 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3992 if (sourceVectorType.getScalableDims()[idx] !=
3993 destVectorType.getScalableDims()[idx + rankDiff]) {
3994 return emitOpError("mismatching scalable flags (at source vector idx=")
3995 << idx << ")";
3996 }
3997 if (sourceVectorType.getScalableDims()[idx]) {
3998 auto sourceSize = sourceShape[idx];
3999 auto destSize = destShape[idx + rankDiff];
4000 if (sourceSize != destSize) {
4001 return emitOpError("expected size at idx=")
4002 << idx
4003 << (" to match the corresponding base size from the input "
4004 "vector (")
4005 << sourceSize << (" vs ") << destSize << (")");
4006 }
4007 }
4008 }
4009
4010 return success();
4011}
4012
4013namespace {
4014/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
4015class FoldInsertStridedSliceSplat final
4016 : public OpRewritePattern<InsertStridedSliceOp> {
4017public:
4018 using Base::Base;
4019
4020 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4021 PatternRewriter &rewriter) const override {
4022
4023 auto dst = insertStridedSliceOp.getDest();
4024 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4025 if (!splat || getScalarSplatSource(dst) != splat)
4026 return failure();
4027
4028 rewriter.replaceOp(insertStridedSliceOp, dst);
4029 return success();
4030 }
4031};
4032
4033/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
4034/// to dst.
4035class FoldInsertStridedSliceOfExtract final
4036 : public OpRewritePattern<InsertStridedSliceOp> {
4037public:
4038 using Base::Base;
4039
4040 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4041 PatternRewriter &rewriter) const override {
4042 auto extractStridedSliceOp =
4043 insertStridedSliceOp.getValueToStore()
4044 .getDefiningOp<vector::ExtractStridedSliceOp>();
4045
4046 if (!extractStridedSliceOp)
4047 return failure();
4048
4049 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4050 return failure();
4051
4052 // Check if have the same strides and offsets.
4053 if (extractStridedSliceOp.getStrides() !=
4054 insertStridedSliceOp.getStrides() ||
4055 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4056 return failure();
4057
4058 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4059 return success();
4060 }
4061};
4062
// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using Base::Base;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if 'InsertOp' operand is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    TypedValue<VectorType> sourceValue = op.getValueToStore();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Support poison.
    if (matchPattern(vectorDestCst, ub::m_Poison()) ||
        matchPattern(sourceCst, ub::m_Poison()))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Calculate the destination element indices by enumerating all slice
    // positions within the destination and linearizing them. The enumeration
    // order is lexicographic which yields a sequence of monotonically
    // increasing linearized position indices.
    // Because the destination may have higher dimensionality than the slice,
    // we keep track of two overlapping sets of positions and offsets.
    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
    auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    // The trailing part of the destination position is the slice position.
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
             "Invalid slice element");
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};
4142
4143} // namespace
4144
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Splat no-op folding, extract/insert round-trip elimination, and
  // constant folding of strided inserts.
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}
4150
4151OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4152 if (getSourceVectorType() == getDestVectorType())
4153 return getValueToStore();
4154 return {};
4155}
4156
4157//===----------------------------------------------------------------------===//
4158// OuterProductOp
4159//===----------------------------------------------------------------------===//
4160
/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}
4167
void OuterProductOp::print(OpAsmPrinter &p) {
  p << " " << getLhs() << ", " << getRhs();
  if (getAcc()) {
    p << ", " << getAcc();
    // NOTE(review): the attribute dictionary is only printed when an
    // accumulator operand is present — confirm ops without acc never carry
    // attributes that must round-trip through the custom syntax.
    p.printOptionalAttrDict((*this)->getAttrs());
  }
  p << " : " << getLhs().getType() << ", " << getRhs().getType();
}
4176
/// Parses `lhs, rhs[, acc] attr-dict : lhsType, rhsType`. The RHS may be a
/// vector (proper outer product) or a scalar (axpy-style); the result type is
/// derived from the operand types rather than parsed.
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
  VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");

  // Derive the result type: 2-D for a vector RHS, 1-D for a scalar RHS,
  // propagating the scalable flags of the operand dimensions.
  VectorType resType;
  if (vRHS) {
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
                                      vRHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                              vLHS.getElementType(), scalableDimsRes);
  } else {
    // Scalar RHS operand
    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
                              scalableDimsRes);
  }

  // Default the combining kind attribute when absent from the attr-dict.
  if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
    result.attributes.append(
        OuterProductOp::getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  // The optional third operand (acc) has the same type as the result.
  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
4221
/// Verify operand/result ranks, matching dim sizes, scalability constraints
/// and the combining kind of an outer product (or AXPY) op.
LogicalResult OuterProductOp::verify() {
  Type tRHS = getOperandTypeRHS();
  VectorType vLHS = getOperandVectorTypeLHS(),
             vRHS = llvm::dyn_cast<VectorType>(tRHS),
             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();

  if (vLHS.getRank() != 1)
    return emitOpError("expected 1-d vector for operand #1");

  if (vRHS) {
    // Proper OUTER operation: 1-d x 1-d -> 2-d.
    if (vRHS.getRank() != 1)
      return emitOpError("expected 1-d vector for operand #2");
    if (vRES.getRank() != 2)
      return emitOpError("expected 2-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
    if (vRHS.getDimSize(0) != vRES.getDimSize(1))
      return emitOpError("expected #2 operand dim to match result dim #2");
    if (vLHS.isScalable() && !vRHS.isScalable()) {
      // This restriction reflects what's currently supported in terms of
      // scalable vectors. However, we could relax this if there's a use case.
      return emitOpError(
          "expected either both or only #2 operand dim to be scalable");
    }
  } else {
    // An AXPY operation: 1-d x scalar -> 1-d.
    if (vRES.getRank() != 1)
      return emitOpError("expected 1-d vector result");
    if (vLHS.getDimSize(0) != vRES.getDimSize(0))
      return emitOpError("expected #1 operand dim to match result dim #1");
  }

  // The optional accumulator must have exactly the result type.
  if (vACC && vACC != vRES)
    return emitOpError("expected operand #3 of same type as result type");

  // The kind attribute is normally installed by the parser/builder; reject
  // ops constructed without one.
  if (!getKindAttr()) {
    return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
                       "'vector.kind<add>')");
  }

  // Verify supported combining kind.
  if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
    return emitOpError("unsupported outerproduct type");

  return success();
}
4269
4270// MaskableOpInterface methods.
4271
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
4274Type OuterProductOp::getExpectedMaskType() {
4275 auto vecType = this->getResultVectorType();
4276 return VectorType::get(vecType.getShape(),
4277 IntegerType::get(vecType.getContext(), /*width=*/1),
4278 vecType.getScalableDims());
4279}
4280
4281//===----------------------------------------------------------------------===//
4282// ExtractStridedSliceOp
4283//===----------------------------------------------------------------------===//
4284
4285// Inference works as follows:
4286// 1. Add 'sizes' from prefix of dims in 'offsets'.
4287// 2. Add sizes from 'vectorType' for remaining dims.
4288// Scalable flags are inherited from 'vectorType'.
4289static Type inferStridedSliceOpResultType(VectorType vectorType,
4290 ArrayAttr offsets, ArrayAttr sizes,
4291 ArrayAttr strides) {
4292 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4294 shape.reserve(vectorType.getRank());
4295 unsigned idx = 0;
4296 for (unsigned e = offsets.size(); idx < e; ++idx)
4297 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4298 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4299 shape.push_back(vectorType.getShape()[idx]);
4300
4301 return VectorType::get(shape, vectorType.getElementType(),
4302 vectorType.getScalableDims());
4303}
4304
/// Build an extract_strided_slice op, inferring the result type from the
/// source vector type and the given offsets/sizes/strides.
void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                                  Value source, ArrayRef<int64_t> offsets,
                                  ArrayRef<int64_t> sizes,
                                  ArrayRef<int64_t> strides) {
  result.addOperands(source);
  // Wrap the integer arrays as I64ArrayAttrs.
  auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
  auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
  auto stridesAttr = getVectorSubscriptAttr(builder, strides);
  result.addTypes(
      inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
                                    offsetsAttr, sizesAttr, stridesAttr));
  result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
                      offsetsAttr);
  result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
                      sizesAttr);
  result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
                      stridesAttr);
}
4323
/// Verify the offsets/sizes/strides attributes against the source vector
/// type, check the declared result type matches the inferred one, and reject
/// partial slices of scalable dimensions.
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getSourceVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  // The three attributes describe one slice entry per (leading) dimension.
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  // Attribute-level checks: no more entries than source dims; offsets within
  // the shape; sizes >= 1; strides exactly 1 (only unit strides supported);
  // and offset + size confined to the source dimension.
  if (failed(
          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
      failed(
          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(
          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  // The declared result type must match the type inferred from the
  // attributes.
  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
                                                  offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  // Scalable dims cannot be partially sliced: the slice size must equal the
  // corresponding base size of the input vector.
  for (unsigned idx = 0; idx < sizes.size(); ++idx) {
    if (type.getScalableDims()[idx]) {
      auto inputDim = type.getShape()[idx];
      auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
      if (inputDim != inputSize)
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << inputSize << (" vs ") << inputDim << (")");
    }
  }

  return success();
}
4376
4377// When the source of ExtractStrided comes from a chain of InsertStrided ops try
4378// to use the source of the InsertStrided ops if we can detect that the
4379// extracted vector is a subset of one of the vector inserted.
4380static LogicalResult
4381foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4382 // Helper to extract integer out of ArrayAttr.
4383 auto getElement = [](ArrayAttr array, int idx) {
4384 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4385 };
4386 ArrayAttr extractOffsets = op.getOffsets();
4387 ArrayAttr extractStrides = op.getStrides();
4388 ArrayAttr extractSizes = op.getSizes();
4389 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4390 while (insertOp) {
4391 if (op.getSourceVectorType().getRank() !=
4392 insertOp.getSourceVectorType().getRank())
4393 return failure();
4394 ArrayAttr insertOffsets = insertOp.getOffsets();
4395 ArrayAttr insertStrides = insertOp.getStrides();
4396 // If the rank of extract is greater than the rank of insert, we are likely
4397 // extracting a partial chunk of the vector inserted.
4398 if (extractOffsets.size() > insertOffsets.size())
4399 return failure();
4400 bool patialoverlap = false;
4401 bool disjoint = false;
4402 SmallVector<int64_t, 4> offsetDiffs;
4403 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4404 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4405 return failure();
4406 int64_t start = getElement(insertOffsets, dim);
4407 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4408 int64_t offset = getElement(extractOffsets, dim);
4409 int64_t size = getElement(extractSizes, dim);
4410 // Check if the start of the extract offset is in the interval inserted.
4411 if (start <= offset && offset < end) {
4412 // If the extract interval overlaps but is not fully included we may
4413 // have a partial overlap that will prevent any folding.
4414 if (offset + size > end)
4415 patialoverlap = true;
4416 offsetDiffs.push_back(offset - start);
4417 continue;
4418 }
4419 disjoint = true;
4420 break;
4421 }
4422 // The extract element chunk is a subset of the insert element.
4423 if (!disjoint && !patialoverlap) {
4424 op.setOperand(insertOp.getValueToStore());
4425 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4426 OpBuilder b(op.getContext());
4427 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4428 return success();
4429 }
4430 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4431 // in the insert chain.
4432 if (disjoint)
4433 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4434 else {
4435 // The extracted vector partially overlap the inserted vector, we cannot
4436 // fold.
4437 return failure();
4438 }
4439 }
4440 return failure();
4441}
4442
4443// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4444static OpFoldResult
4446 Attribute foldInput) {
4447
4448 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4449 if (!dense)
4450 return {};
4451
4452 // TODO: Handle non-unit strides when they become available.
4453 if (op.hasNonUnitStrides())
4454 return {};
4455
4456 VectorType sourceVecTy = op.getSourceVectorType();
4457 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4458 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4459
4460 VectorType sliceVecTy = op.getType();
4461 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4462 int64_t rank = sliceVecTy.getRank();
4463
4464 // Expand offsets and sizes to match the vector rank.
4465 SmallVector<int64_t, 4> offsets(rank, 0);
4466 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4467
4468 SmallVector<int64_t, 4> sizes(sourceShape);
4469 copy(getI64SubArray(op.getSizes()), sizes.begin());
4470
4471 // Calculate the slice elements by enumerating all slice positions and
4472 // linearizing them. The enumeration order is lexicographic which yields a
4473 // sequence of monotonically increasing linearized position indices.
4474 const auto denseValuesBegin = dense.value_begin<Attribute>();
4475 SmallVector<Attribute> sliceValues;
4476 sliceValues.reserve(sliceVecTy.getNumElements());
4477 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4478 do {
4479 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4480 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4481 "Invalid index");
4482 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4483 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4484
4485 assert(static_cast<int64_t>(sliceValues.size()) ==
4486 sliceVecTy.getNumElements() &&
4487 "Invalid number of slice elements");
4488 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4489}
4490
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
  // Identity slice: the whole source is extracted.
  if (getSourceVectorType() == getResult().getType())
    return getSource();
  // Try to fold through a chain of insert_strided_slice ops; on success the
  // op was updated in place to read from the inserted value.
  if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
    return getResult();

  // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
    return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());

  // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
  return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
}
4505
/// Populate `results` with the offsets attribute decoded as int64_t values.
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getOffsets(), results);
}
4509
4510namespace {
4511
4512// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
4513//
4514// Example:
4515//
4516// %0 = vector.extract_strided_slice %arg0
4517// {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
4518// : vector<4x8x16xf32> to vector<3x4x16xf32>
4519// %1 = vector.extract_strided_slice %0
4520// {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
4521// : vector<3x4x16xf32> to vector<2x2x16xf32>
4522//
4523// to
4524//
4525// %1 = vector.extract_strided_slice %arg0
4526// {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
4527// : vector<4x8x16xf32> to vector<2x2x16xf32>
class StridedSliceFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
                                PatternRewriter &rewriter) const override {
    auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
    if (!firstOp)
      return failure();

    // The offset composition below assumes unit strides on both ops.
    if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
      return failure();

    SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
    SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
    SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
    SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());

    // The combined op needs as many slice entries as the deeper of the two.
    unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
    SmallVector<int64_t> combinedOffsets(newRank, 0);
    SmallVector<int64_t> combinedSizes(newRank);
    ArrayRef<int64_t> firstSourceShape =
        firstOp.getSourceVectorType().getShape();
    for (unsigned i = 0; i < newRank; ++i) {
      // With unit strides, offsets simply add; a missing entry acts as 0.
      int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
      int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
      combinedOffsets[i] = off1 + off2;

      // The second (outer) slice determines the final size where present;
      // otherwise the first slice's size, or the full source dim, applies.
      if (i < secondSizes.size()) {
        combinedSizes[i] = secondSizes[i];
      } else if (i < firstSizes.size()) {
        combinedSizes[i] = firstSizes[i];
      } else {
        combinedSizes[i] = firstSourceShape[i];
      }
    }

    SmallVector<int64_t> combinedStrides(newRank, 1);
    rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
        secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
        combinedStrides);
    return success();
  }
};
4573
4574// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4575// CreateMaskOp.
4576//
4577// Example:
4578//
4579// %mask = vector.create_mask %ub : vector<16xi1>
4580// %slice = vector.extract_strided_slice [%offset] [8] [1]
4581//
4582// to
4583//
4584// %new_ub = arith.subi %ub, %offset
4585// %mask = vector.create_mask %new_ub : vector<8xi1>
class StridedSliceCreateMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
  using Base::Base;

public:
  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    Location loc = extractStridedSliceOp.getLoc();
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // CreateMaskOp.
    auto createMaskOp =
        extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
    if (!createMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute slice of vector mask region.
    SmallVector<Value> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
    // only iterate on the leading dim sizes. The tail accounts for the
    // remaining dim sizes.
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      // No need to clamp on min/max values, because create_mask has clamping
      // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
      // greater than the vector dim size.
      IntegerAttr offsetAttr =
          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
      Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
      // New bound for this dim = old bound - slice offset (as IR values).
      Value sliceMaskDimSize =
          arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions.
    llvm::append_range(
        sliceMaskDimSizes,
        llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
    // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
    // region.
    rewriter.replaceOpWithNewOp<CreateMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        sliceMaskDimSizes);
    return success();
  }
};
4642
4643// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4644// ConstantMaskOp.
4645class StridedSliceConstantMaskFolder final
4646 : public OpRewritePattern<ExtractStridedSliceOp> {
4647public:
4648 using Base::Base;
4649
4650 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4651 PatternRewriter &rewriter) const override {
4652 // Return if 'extractStridedSliceOp' operand is not defined by a
4653 // ConstantMaskOp.
4654 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4655 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4656 if (!constantMaskOp)
4657 return failure();
4658 // Return if 'extractStridedSliceOp' has non-unit strides.
4659 if (extractStridedSliceOp.hasNonUnitStrides())
4660 return failure();
4661 // Gather constant mask dimension sizes.
4662 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4663 // Gather strided slice offsets and sizes.
4664 SmallVector<int64_t> sliceOffsets;
4665 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4666 sliceOffsets);
4667 SmallVector<int64_t> sliceSizes;
4668 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4669
4670 // Compute slice of vector mask region.
4671 SmallVector<int64_t> sliceMaskDimSizes;
4672 sliceMaskDimSizes.reserve(maskDimSizes.size());
4673 for (auto [maskDimSize, sliceOffset, sliceSize] :
4674 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4675 int64_t sliceMaskDimSize = std::max(
4676 static_cast<int64_t>(0),
4677 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4678 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4679 }
4680 // Add unchanged dimensions.
4681 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4682 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4683 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4684 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4685 // region is a conjunction of mask dim intervals).
4686 if (llvm::is_contained(sliceMaskDimSizes, 0))
4687 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4688
4689 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4690 // region.
4691 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4692 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4693 sliceMaskDimSizes);
4694 return success();
4695 }
4696};
4697
4698// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4699// BroadcastOp(ExtractStrideSliceOp).
class StridedSliceBroadcast final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
    if (!broadcast)
      return failure();
    auto srcVecType =
        llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
    // A scalar broadcast source is treated as rank 0.
    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
    auto dstVecType = llvm::cast<VectorType>(op.getType());
    unsigned dstRank = dstVecType.getRank();
    unsigned rankDiff = dstRank - srcRank;
    // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
    // (n -> m with n > m). If they are originally both broadcasted *and*
    // sliced, this can be simplified to just broadcasting.
    bool needsSlice = false;
    for (unsigned i = 0; i < srcRank; i++) {
      if (srcVecType.getDimSize(i) != 1 &&
          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
        needsSlice = true;
        break;
      }
    }
    Value source = broadcast.getSource();
    if (needsSlice) {
      // Only the trailing dims of the slice (those aligned with the broadcast
      // source) apply to the source; drop the leading rankDiff entries.
      SmallVector<int64_t> offsets =
          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
      SmallVector<int64_t> sizes =
          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
      for (unsigned i = 0; i < srcRank; i++) {
        if (srcVecType.getDimSize(i) == 1) {
          // In case this dimension was broadcasted *and* sliced, the offset
          // and size need to be updated now that there is no broadcast before
          // the slice.
          offsets[i] = 0;
          sizes[i] = 1;
        }
      }
      source = ExtractStridedSliceOp::create(
          rewriter, op->getLoc(), source, offsets, sizes,
          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
    }
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
    return success();
  }
};
4750
4751/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
4752class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4753public:
4754 using Base::Base;
4755
4756 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4757 PatternRewriter &rewriter) const override {
4758
4759 Value splat = getScalarSplatSource(op.getSource());
4760 if (!splat)
4761 return failure();
4762 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4763 return success();
4764 }
4765};
4766
4767/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4768/// slice is contiguous, into extract and shape_cast.
4769///
4770/// Example:
4771/// Before:
4772/// %1 = vector.extract_strided_slice %arg0 {
4773/// offsets = [0, 0, 0, 0, 0],
4774/// sizes = [1, 1, 1, 1, 8],
4775/// strides = [1, 1, 1, 1, 1]
4776/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4777/// After:
4778/// %0 = vector.extract %arg0[0, 0, 0, 0]
4779/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4780/// %1 = vector.shape_cast %0
4781/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4782///
class ContiguousExtractStridedSliceToExtract final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Only unit strides can be expressed as a plain extract.
    if (op.hasNonUnitStrides())
      return failure();
    Value source = op.getOperand();
    auto sourceType = cast<VectorType>(source.getType());
    if (sourceType.isScalable() || sourceType.getRank() == 0)
      return failure();

    // Compute the number of offsets to pass to ExtractOp::build. That is the
    // difference between the source rank and the desired slice rank. We walk
    // the dimensions from innermost out, and stop when the next slice dimension
    // is not full-size.
    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
    int numOffsets;
    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
        break;
    }

    // If the created extract op would have no offsets, then this whole
    // extract_strided_slice is the identity and should have been handled by
    // other canonicalizations.
    if (numOffsets == 0)
      return failure();

    // If not even the inner-most dimension is full-size, this op can't be
    // rewritten as an ExtractOp.
    if (numOffsets == sourceType.getRank() &&
        static_cast<int>(sizes.size()) == sourceType.getRank())
      return failure();

    // The outer dimensions must have unit size.
    for (int i = 0; i < numOffsets; ++i) {
      if (sizes[i] != 1)
        return failure();
    }

    // Avoid generating slices that have leading unit dimensions. The shape_cast
    // op that we create below would take bad generic fallback patterns
    // (ShapeCastOpRewritePattern).
    while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
           sizes[numOffsets] == 1) {
      ++numOffsets;
    }

    // Extract at the leading offsets, then shape_cast to the slice type.
    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
    Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
                                              extractOffsets);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
    return success();
  }
};
4842
4843} // namespace
4844
/// Registers the ExtractStridedSliceOp canonicalization patterns defined in
/// the anonymous namespace above.
void ExtractStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
  results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
              StridedSliceConstantMaskFolder, StridedSliceBroadcast,
              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
      context);
}
4854
4855//===----------------------------------------------------------------------===//
4856// TransferReadOp
4857//===----------------------------------------------------------------------===//
4858
4859/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMapAttr permutationMapAttr,
                           /*optional*/ ArrayAttr inBoundsAttr) {

  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  // Default the padding value to poison when the caller did not provide one.
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding, /*mask=*/Value(), inBoundsAttr);
}
4872
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           AffineMap permutationMap,
                           std::optional<ArrayRef<bool>> inBounds) {
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  // Default in_bounds to all-false (out-of-bounds) when not provided.
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  // Default the padding value to poison when the caller did not provide one.
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, *padding,
        permutationMapAttr, inBoundsAttr);
}
4890
4891/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
void TransferReadOp::build(OpBuilder &builder, OperationState &result,
                           VectorType vectorType, Value source,
                           ValueRange indices, std::optional<Value> padding,
                           std::optional<ArrayRef<bool>> inBounds) {
  // Derive the minor identity permutation map from source and vector types.
  AffineMap permutationMap = getTransferMinorIdentityMap(
      llvm::cast<ShapedType>(source.getType()), vectorType);
  auto permutationMapAttr = AffineMapAttr::get(permutationMap);
  // Default in_bounds to all-false (out-of-bounds) when not provided.
  auto inBoundsAttr = (inBounds && !inBounds.value().empty())
                          ? builder.getBoolArrayAttr(inBounds.value())
                          : builder.getBoolArrayAttr(
                                SmallVector<bool>(vectorType.getRank(), false));
  Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  // Default the padding value to poison when the caller did not provide one.
  if (!padding)
    padding = ub::PoisonOp::create(builder, result.location, elemType);
  build(builder, result, vectorType, source, indices, permutationMapAttr,
        *padding,
        /*mask=*/Value(), inBoundsAttr);
}
4910
4911template <typename EmitFun>
4912static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4913 EmitFun emitOpError) {
4914 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4915 for (auto expr : permutationMap.getResults()) {
4916 auto dim = dyn_cast<AffineDimExpr>(expr);
4917 auto zero = dyn_cast<AffineConstantExpr>(expr);
4918 if (zero) {
4919 if (zero.getValue() != 0) {
4920 return emitOpError(
4921 "requires a projected permutation_map (at most one dim or the zero "
4922 "constant can appear in each result)");
4923 }
4924 continue;
4925 }
4926 if (!dim) {
4927 return emitOpError("requires a projected permutation_map (at most one "
4928 "dim or the zero constant can appear in each result)");
4929 }
4930 if (seen[dim.getPosition()]) {
4931 return emitOpError(
4932 "requires a permutation_map that is a permutation (found one dim "
4933 "used more than once)");
4934 }
4935 seen[dim.getPosition()] = true;
4936 }
4937 return success();
4938}
4939
4940static LogicalResult
4941verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4942 VectorType vectorType, VectorType maskType,
4943 VectorType inferredMaskType, AffineMap permutationMap,
4944 ArrayAttr inBounds) {
4945 if (op->hasAttr("masked")) {
4946 return op->emitOpError("masked attribute has been removed. "
4947 "Use in_bounds instead.");
4948 }
4949
4950 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4951 return op->emitOpError(
4952 "requires source to be a memref or ranked tensor type");
4953
4954 auto elementType = shapedType.getElementType();
4955 DataLayout dataLayout = DataLayout::closest(op);
4956 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4957 // Memref or tensor has vector element type.
4958 unsigned sourceVecSize =
4959 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4960 vectorElementType.getShape().back();
4961 unsigned resultVecSize =
4962 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4963 vectorType.getShape().back();
4964 if (resultVecSize % sourceVecSize != 0)
4965 return op->emitOpError(
4966 "requires the bitwidth of the minor 1-D vector to be an integral "
4967 "multiple of the bitwidth of the minor 1-D vector of the source");
4968
4969 unsigned sourceVecEltRank = vectorElementType.getRank();
4970 unsigned resultVecRank = vectorType.getRank();
4971 if (sourceVecEltRank > resultVecRank)
4972 return op->emitOpError(
4973 "requires source vector element and vector result ranks to match.");
4974 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4975 // Check that permutation map results match 'rankOffset' of vector type.
4976 if (permutationMap.getNumResults() != rankOffset)
4977 return op->emitOpError("requires a permutation_map with result dims of "
4978 "the same rank as the vector type");
4979
4980 if (maskType)
4981 return op->emitOpError("does not support masks with vector element type");
4982 } else {
4983 // Memref or tensor has scalar element type.
4984 unsigned minorSize =
4985 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4986 unsigned resultVecSize =
4987 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4988 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4989 return op->emitOpError(
4990 "requires the bitwidth of the minor 1-D vector to be an integral "
4991 "multiple of the bitwidth of the source element type");
4992
4993 // Check that permutation map results match rank of vector type.
4994 if (permutationMap.getNumResults() != vectorType.getRank())
4995 return op->emitOpError("requires a permutation_map with result dims of "
4996 "the same rank as the vector type");
4997 }
4998
4999 if (permutationMap.getNumSymbols() != 0)
5000 return op->emitOpError("requires permutation_map without symbols");
5001
5002 if (permutationMap.getNumInputs() != shapedType.getRank())
5003 return op->emitOpError("requires a permutation_map with input dims of the "
5004 "same rank as the source type");
5005
5006 if (maskType && maskType != inferredMaskType)
5007 return op->emitOpError("inferred mask type (")
5008 << inferredMaskType << ") and mask operand type (" << maskType
5009 << ") don't match";
5010
5011 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
5012 return op->emitOpError("expects the in_bounds attr of same rank "
5013 "as permutation_map results: ")
5014 << AffineMapAttr::get(permutationMap)
5015 << " vs inBounds of size: " << inBounds.size();
5016
5017 return success();
5018}
5019
5020static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
5021 SmallVector<StringRef, 3> elidedAttrs;
5022 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5023 if (op.getPermutationMap().isMinorIdentity())
5024 elidedAttrs.push_back(op.getPermutationMapAttrName());
5025 // Elide in_bounds attribute if all dims are out-of-bounds.
5026 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
5027 elidedAttrs.push_back(op.getInBoundsAttrName());
5028 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5029}
5030
5031void TransferReadOp::print(OpAsmPrinter &p) {
5032 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
5033 if (getMask())
5034 p << ", " << getMask();
5035 printTransferAttrs(p, *this);
5036 p << " : " << getShapedType() << ", " << getVectorType();
5037}
5038
5039VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
5040 AffineMap permMap) {
5041 auto i1Type = IntegerType::get(permMap.getContext(), 1);
5042 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
5043 assert(invPermMap && "Inversed permutation map couldn't be computed");
5044 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
5045
5046 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
5047 // 0-D mask into a single-element 1-D mask.
5048 if (maskShape.empty())
5049 maskShape.push_back(1);
5050
5051 SmallVector<bool> scalableDims =
5052 applyPermutationMap(invPermMap, vecType.getScalableDims());
5053
5054 return VectorType::get(maskShape, i1Type, scalableDims);
5055}
5056
/// Parses the custom assembly of transfer_read:
///   `base[indices], padding (, mask)? attr-dict : shaped-type, vector-type`
/// The permutation_map and in_bounds attributes are synthesized when absent,
/// and the mask type is inferred rather than spelled in the signature.
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  // The mask operand is optional and introduced by a trailing comma.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types: the shaped source type and the result vector type.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    // No explicit map: default to the minor identity, which only exists when
    // the source rank is large enough.
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    // Default: every dimension is out-of-bounds.
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  // Operand segments: base, indices, padding, optional mask.
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
5134
5135LogicalResult TransferReadOp::verify() {
5136 // Consistency of elemental types in source and vector.
5137 ShapedType shapedType = getShapedType();
5138 VectorType vectorType = getVectorType();
5139 VectorType maskType = getMaskType();
5140 auto paddingType = getPadding().getType();
5141 auto permutationMap = getPermutationMap();
5142 VectorType inferredMaskType =
5143 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5144 : VectorType();
5145 auto sourceElementType = shapedType.getElementType();
5146
5147 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
5148 return emitOpError("requires ") << shapedType.getRank() << " indices";
5149
5150 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5151 shapedType, vectorType, maskType,
5152 inferredMaskType, permutationMap, getInBounds())))
5153 return failure();
5154
5155 if (auto sourceVectorElementType =
5156 llvm::dyn_cast<VectorType>(sourceElementType)) {
5157 // Source has vector element type.
5158 // Check that 'sourceVectorElementType' and 'paddingType' types match.
5159 if (sourceVectorElementType != paddingType)
5160 return emitOpError(
5161 "requires source element type and padding type to match.");
5162
5163 } else {
5164 // Check that 'paddingType' is valid to store in a vector type.
5165 if (!VectorType::isValidElementType(paddingType))
5166 return emitOpError("requires valid padding vector elemental type");
5167
5168 // Check that padding type and vector element types match.
5169 if (paddingType != sourceElementType)
5170 return emitOpError(
5171 "requires formal padding and source of the same elemental type");
5172 }
5173
5174 return verifyPermutationMap(permutationMap,
5175 [&](Twine t) { return emitOpError(t); });
5176}
5177
5178// MaskableOpInterface methods.
5179
5180/// Returns the mask type expected by this operation. Mostly used for
5181/// verification purposes. It requires the operation to be vectorized."
5182Type TransferReadOp::getExpectedMaskType() {
5183 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5184}
5185
5186//===----------------------------------------------------------------------===//
5187// TransferReadOp: VectorTransferOpInterface methods.
5188//===----------------------------------------------------------------------===//
5189VectorType TransferReadOp::getVectorType() {
5190 return cast<VectorType>(getVector().getType());
5191}
5192
/// Returns true when the access along `resultIdx`/`indicesIdx` is statically
/// known to stay within the source's bounds (constant index + static dim).
template <typename TransferOp>
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
  // TODO: support more aggressive createOrFold on:
  // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
  if (op.getShapedType().isDynamicDim(indicesIdx))
    return false;
  std::optional<int64_t> offset =
      getConstantIntValue(op.getIndices()[indicesIdx]);
  if (!offset)
    return false;

  int64_t dimSize = op.getShapedType().getDimSize(indicesIdx);
  int64_t vecSize = op.getVectorType().getDimSize(resultIdx);
  // In-bounds iff the last accessed element is within the dimension.
  return *offset + vecSize <= dimSize;
}
5209
/// Tightens the in_bounds attribute of a transfer op in place: dims that can
/// be statically proven in-bounds are switched from `false` to `true`.
/// Returns success iff the attribute was changed.
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  // TODO: Be less conservative.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  // Idxs of non-bcast dims - used when analysing bcast dims.
  SmallVector<unsigned> nonBcastDims;

  // 1. Process non-broadcast dims
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // 1.1. Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // 1.2. Currently out-of-bounds, check whether we can statically determine
    // it is inBounds.
    bool inBounds = false;
    auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
    if (dimExpr) {
      // A dim expr maps this vector dim to a concrete source dim; check it.
      inBounds = isInBounds(op, /*resultIdx=*/i,
                            /*indicesIdx=*/dimExpr.getPosition());
      nonBcastDims.push_back(i);
    }

    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }

  // 2. Handle broadcast dims
  // If all non-broadcast dims are "in bounds", then all bcast dims should be
  // "in bounds" as well.
  bool allNonBcastDimsInBounds = llvm::all_of(
      nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
  if (allNonBcastDimsInBounds) {
    for (size_t idx : permutationMap.getBroadcastDims()) {
      // Record whether this actually flipped a dim, then mark it in-bounds.
      changed |= !newInBounds[idx];
      newInBounds[idx] = true;
    }
  }

  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
  return success();
}
5264
/// Drops the mask operand of a transfer op when folding determines it to be
/// redundant (the guarded condition bails out otherwise), turning the op into
/// its unmasked form in place. Returns success iff the op was changed.
template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  // Nothing to fold when the transfer carries no mask.
  if (!mask)
    return failure();

    return failure();

  // The mask is redundant: remove the operand in place.
  op.getMaskMutable().clear();
  return success();
}
5277
5278/// ```
5279/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5280/// : vector<1x4xf32>, tensor<4x4xf32>
5281/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5282/// : tensor<4x4xf32>, vector<1x4xf32>
5283/// ```
5284/// -> Folds into
5285/// ```
5286/// %v0
5287/// ```
5288static Value foldRAW(TransferReadOp readOp) {
5289 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5290 return {};
5291 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5292 while (defWrite) {
5293 if (checkSameValueRAW(defWrite, readOp))
5294 return defWrite.getVector();
5296 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5297 cast<VectorTransferOpInterface>(readOp.getOperation())))
5298 break;
5299 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5300 }
5301 return {};
5302}
5303
OpFoldResult TransferReadOp::fold(FoldAdaptor) {
  // Store-to-load forwarding replaces the whole read by the written vector.
  if (Value vec = foldRAW(*this))
    return vec;
  // The remaining folds rewrite the op in place; signal that by returning the
  // op's own result.
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  if (succeeded(foldTransferFullMask(*this)))
    return getResult();
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}
5318
5319std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5320 return llvm::to_vector<4>(getVectorType().getShape());
5321}
5322
5323void TransferReadOp::getEffects(
5324 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5325 &effects) {
5326 if (llvm::isa<MemRefType>(getShapedType()))
5327 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5328 SideEffects::DefaultResource::get());
5329}
5330
Speculation::Speculatability TransferReadOp::getSpeculatability() {
  // With pure tensor semantics no memory is touched, so the read is safe to
  // speculate.
  if (hasPureTensorSemantics())
}
5336
/// Given a projected permutation, inverts the affine map; input dims that the
/// map never mentions become zero results in the inverse.
static AffineMap inverseWithUnusedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected a projected permutation map");
  // Pre-fill every result slot; unused input dims keep the initial value.
  SmallVector<AffineExpr> results(map.getNumInputs(),
  for (auto [idx, result] : llvm::enumerate(map.getResults())) {
    // We should only have dim exprs because this is a projected permutation.
    int64_t pos = cast<AffineDimExpr>(result).getPosition();
    results[pos] = getAffineDimExpr(idx, map.getContext());
  }
  // The inverse maps the original results back to the original inputs.
  return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
                        results, map.getContext());
}
5352
namespace {
/// Store to load forwarding for transfer operations with permutation maps.
/// Even if the permutation maps are different we can still propagate the store
/// into the load if the size of the dimensions read and written match. Then we
/// can replace the transfer_read + transfer_write by vector.broadcast and
/// vector.transpose.
/// Example:
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
///  {in_bounds = [true, true],
///   permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
///   vector<4x1xf32>, tensor<4x4x4xf32>
///  %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
///   {in_bounds = [true, true, true, true],
///   permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
///   tensor<4x4x4xf32>, vector<1x100x4x5xf32>
/// ```
/// To:
/// ```
/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
/// %r = vector.transpose %0, [3, 0, 2, 1] :
///   vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
/// ```
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();
    // Bail if we need an alias analysis.
    if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
      return failure();
    // Bail in the masked case (too complex atm and needed to properly account
    // for padding).
    if (readOp.getMask() || defWrite.getMask())
      return failure();
    // If indices are not the same a shift may be required, bail.
    if (readOp.getIndices() != defWrite.getIndices())
      return failure();
    // Bail if we need a bounds analysis.
    if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
      return failure();
    // TODO: If the written transfer chunk is a superset of the read transfer
    // chunk we could do an extract_strided_slice.
    if (readOp.getTransferChunkAccessed() !=
        defWrite.getTransferChunkAccessed())
      return failure();
    // WriteMap: tensor -> w_vec
    // ReadMap: tensor -> r_vec
    //
    // inv(WriteMap): w_vec -> tensor
    // inv(WriteMap) o ReadMap: w_vec -> r_vec
    AffineMap readMap = readOp.getPermutationMap();
    AffineMap writeMap = defWrite.getPermutationMap();
    AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
    AffineMap composedMap = readMap.compose(invWriteMap);
    // If there are any unused dims in the composedMap, we have to drop some
    // unit dims from the written vector before we can do transpose(broadcast).
    // TODO: Support this case.
    if (getUnusedDimsBitVector(composedMap).any())
      return failure();
    // readVec = transpose(broadcast(writeVec))
    //
    // Build a transpose permutation for the above transpose operation.
    //
    // Treat the composed map as having extra leading dimensions which are
    // the broadcasted dimensions, and treat the zeros as these new broadcasted
    // dimensions.
    SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
    int64_t numBroadcastedDims = broadcastedDims.size();
    // Broadcast dims occupy the leading slots of the inverse permutation.
    auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
    invPerm.resize(composedMap.getNumResults());
    for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
      if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
        // Shift write-vector dims past the synthetic broadcast dims.
        int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
        invPerm[effectiveDim] = idx;
      }
    }
    // Applying the inverse permutation on the readVecTy will give us the
    // broadcast result type.
    VectorType readVecTy = readOp.getVectorType();
    SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
    auto broadcastedVecTy =
        VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
                        readVecTy.getElementType(),
                        applyPermutation(readVecTy.getScalableDims(), invPerm));
    // Build the transpose(broadcast) transformation.
    Value vec = defWrite.getVector();
    Location loc = readOp.getLoc();
    vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
    return success();
  }
};
} // namespace
5451
5452void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5453 MLIRContext *context) {
5454 results.add<TransferReadAfterWriteToBroadcast>(context);
5455}
5456
FailureOr<std::optional<SmallVector<Value>>>
TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
  // Cast bubbling is only performed for reads with buffer (memref) semantics.
  if (!hasPureBufferSemantics())
    return failure();
      getResult());
}
5464
5465//===----------------------------------------------------------------------===//
5466// TransferWriteOp
5467//===----------------------------------------------------------------------===//
5468
5469/// 1. Builder with type inference.
5470void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5471 Value vector, Value dest, ValueRange indices,
5472 AffineMapAttr permutationMapAttr,
5473 /*optional*/ Value mask,
5474 /*optional*/ ArrayAttr inBoundsAttr) {
5475 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5476 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5477 mask, inBoundsAttr);
5478}
5479
5480/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5481void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5482 Value vector, Value dest, ValueRange indices,
5483 AffineMapAttr permutationMapAttr,
5484 /*optional*/ ArrayAttr inBoundsAttr) {
5485 build(builder, result, vector, dest, indices, permutationMapAttr,
5486 /*mask=*/Value(), inBoundsAttr);
5487}
5488
5489/// 3. Builder with type inference that sets an empty mask (variant without
5490/// attrs)
5491void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5492 Value vector, Value dest, ValueRange indices,
5493 AffineMap permutationMap,
5494 std::optional<ArrayRef<bool>> inBounds) {
5495 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5496 auto inBoundsAttr =
5497 (inBounds && !inBounds.value().empty())
5498 ? builder.getBoolArrayAttr(inBounds.value())
5499 : builder.getBoolArrayAttr(SmallVector<bool>(
5500 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5501 build(builder, result, vector, dest, indices, permutationMapAttr,
5502 /*mask=*/Value(), inBoundsAttr);
5503}
5504
5505/// 4. Builder with type inference that sets an empty mask and sets permutation
5506/// map to 'getMinorIdentityMap'.
5507void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5508 Value vector, Value dest, ValueRange indices,
5509 std::optional<ArrayRef<bool>> inBounds) {
5510 auto vectorType = llvm::cast<VectorType>(vector.getType());
5511 AffineMap permutationMap = getTransferMinorIdentityMap(
5512 llvm::cast<ShapedType>(dest.getType()), vectorType);
5513 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5514}
5515
/// Parses the custom assembly of transfer_write:
///   `vector, base[indices] (, mask)? attr-dict : vector-type, shaped-type`
/// Synthesizes default permutation_map/in_bounds attributes when absent and
/// infers the mask type from the vector type and permutation map.
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  // The mask operand is optional and introduced by a trailing comma.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types: the stored vector type and the shaped destination type.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  auto permMapAttrName =
      TransferWriteOp::getPermutationMapAttrName(result.name);
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    // No explicit map: default to the minor identity, which only exists when
    // the destination rank is large enough.
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    // Default: every dimension is out-of-bounds.
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // The mask type is inferred, not spelled in the type signature.
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  // Operand segments: vector, base, indices, optional mask.
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  // Tensor destinations produce a result tensor; memref destinations do not.
  return failure(llvm::isa<RankedTensorType>(shapedType) &&
                 parser.addTypeToList(shapedType, result.types));
}
5589
5590void TransferWriteOp::print(OpAsmPrinter &p) {
5591 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5592 if (getMask())
5593 p << ", " << getMask();
5594 printTransferAttrs(p, *this);
5595 p << " : " << getVectorType() << ", " << getShapedType();
5596}
5597
5598LogicalResult TransferWriteOp::verify() {
5599 // Consistency of elemental types in shape and vector.
5600 ShapedType shapedType = getShapedType();
5601 VectorType vectorType = getVectorType();
5602 VectorType maskType = getMaskType();
5603 auto permutationMap = getPermutationMap();
5604 VectorType inferredMaskType =
5605 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5606 : VectorType();
5607
5608 if (llvm::size(getIndices()) != shapedType.getRank())
5609 return emitOpError("requires ") << shapedType.getRank() << " indices";
5610
5611 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5612 // as the semantics is unclear. This can be revisited later if necessary.
5613 if (hasBroadcastDim())
5614 return emitOpError("should not have broadcast dimensions");
5615
5616 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5617 shapedType, vectorType, maskType,
5618 inferredMaskType, permutationMap, getInBounds())))
5619 return failure();
5620
5621 return verifyPermutationMap(permutationMap,
5622 [&](Twine t) { return emitOpError(t); });
5623}
5624
5625//===----------------------------------------------------------------------===//
5626// TransferWriteOp: MaskableOpInterface methods.
5627//===----------------------------------------------------------------------===//
5628
5629/// Returns the mask type expected by this operation. Mostly used for
5630/// verification purposes.
5631Type TransferWriteOp::getExpectedMaskType() {
5632 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5633}
5634
5635//===----------------------------------------------------------------------===//
5636// TransferWriteOp: VectorTransferOpInterface methods.
5637//===----------------------------------------------------------------------===//
5638Value TransferWriteOp::getVector() { return getOperand(0); }
5639VectorType TransferWriteOp::getVectorType() {
5640 return cast<VectorType>(getValueToStore().getType());
5641}
5642
5643//===----------------------------------------------------------------------===//
5644// TransferWriteOp: fold methods.
5645//===----------------------------------------------------------------------===//
5646/// Fold:
5647/// ```
5648/// %t1 = ...
5649/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5650/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5651/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5652/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5653/// ```
5654///
5655/// into:
5656///
5657/// ```
5658/// %t0
5659/// ```
5660///
5661/// The producer of t1 may or may not be DCE'd depending on whether it is a
5662/// block argument or has side effects.
5663static LogicalResult foldReadInitWrite(TransferWriteOp write,
5664 ArrayRef<Attribute>,
5665 SmallVectorImpl<OpFoldResult> &results) {
5666 // TODO: support 0-d corner case.
5667 if (write.getTransferRank() == 0)
5668 return failure();
5669 auto rankedTensorType =
5670 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5671 // If not operating on tensors, bail.
5672 if (!rankedTensorType)
5673 return failure();
5674 // If no read, bail.
5675 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5676 if (!read)
5677 return failure();
5678 // TODO: support 0-d corner case.
5679 if (read.getTransferRank() == 0)
5680 return failure();
5681 // For now, only accept minor identity. Future: composition is minor identity.
5682 if (!read.getPermutationMap().isMinorIdentity() ||
5683 !write.getPermutationMap().isMinorIdentity())
5684 return failure();
5685 // Bail on mismatching ranks.
5686 if (read.getTransferRank() != write.getTransferRank())
5687 return failure();
5688 // Bail on potential out-of-bounds accesses.
5689 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5690 return failure();
5691 // Tensor types must be the same.
5692 if (read.getBase().getType() != rankedTensorType)
5693 return failure();
5694 // Vector types must be the same.
5695 if (read.getVectorType() != write.getVectorType())
5696 return failure();
5697 // Vector and Tensor shapes must match.
5698 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5699 return failure();
5700 // If any index is nonzero.
5701 auto isNotConstantZero = [](Value v) {
5702 auto cstOp = getConstantIntValue(v);
5703 return !cstOp.has_value() || cstOp.value() != 0;
5704 };
5705 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5706 llvm::any_of(write.getIndices(), isNotConstantZero))
5707 return failure();
5708 // Success.
5709 results.push_back(read.getBase());
5710 return success();
5711}
5712
5713static bool checkSameValueWAR(vector::TransferReadOp read,
5714 vector::TransferWriteOp write) {
5715 return read.getBase() == write.getBase() &&
5716 read.getIndices() == write.getIndices() &&
5717 read.getPermutationMap() == write.getPermutationMap() &&
5718 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5719 !write.getMask();
5720}
5721/// Fold transfer_write write after read:
5722/// ```
5723/// %t0 = ...
5724/// %v = vector.transfer_read %t0[%c0...] :
5725/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5726/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5727/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5728/// ```
5729///
5730/// into:
5731///
5732/// ```
5733/// %t0
5734/// ```
5735static LogicalResult foldWAR(TransferWriteOp write,
5736 SmallVectorImpl<OpFoldResult> &results) {
5737 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5738 return failure();
5739 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5740 if (!read)
5741 return failure();
5742
5743 if (!checkSameValueWAR(read, write))
5744 return failure();
5745 results.push_back(read.getBase());
5746 return success();
5747}
5748
5749LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5750 SmallVectorImpl<OpFoldResult> &results) {
5751 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5752 return success();
5753 if (succeeded(foldWAR(*this, results)))
5754 return success();
5755 if (succeeded(foldTransferInBoundsAttribute(*this)))
5756 return success();
5757 if (succeeded(foldTransferFullMask(*this)))
5758 return success();
5759 return memref::foldMemRefCast(*this);
5760}
5761
5762//===----------------------------------------------------------------------===//
5763// TransferWriteOp: other methods.
5764//===----------------------------------------------------------------------===//
5765std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5766 return llvm::to_vector<4>(getVectorType().getShape());
5767}
5768
5769void TransferWriteOp::getEffects(
5770 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5771 &effects) {
5772 if (llvm::isa<MemRefType>(getShapedType()))
5773 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5774 SideEffects::DefaultResource::get());
5775}
5776
Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  // Tensor-semantics writes produce a new SSA value rather than mutating
  // memory, so they can be treated differently from memref writes.
  // NOTE(review): the return statements are missing from this excerpt —
  // presumably Speculatable for the tensor case and NotSpeculatable
  // otherwise; verify against the upstream file.
  if (hasPureTensorSemantics())
}
5782
5783namespace {
/// Remove dead transfer write from the SSA chain so that it can be eliminated
/// by DCE
5786/// ```
5787/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5788/// : vector<1x4xf32>, tensor<4x4xf32>
5789/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5790/// : vector<1x4xf32>, tensor<4x4xf32>
5791/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5792/// : vector<1x4xf32>, tensor<4x4xf32>
5793/// ```
5794///
5795/// into:
5796///
5797/// ```
5798/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5799/// : vector<1x4xf32>, tensor<4x4xf32>
5800/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5801/// : vector<1x4xf32>, tensor<4x4xf32>
5802/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5803/// : vector<1x4xf32>, tensor<4x4xf32>
5804/// ```
5805///
5806/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5807/// any other uses.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using Base::Base;
  /// Walk the chain of tensor transfer_writes feeding `writeOp`'s base and
  /// redirect the base past any earlier write that stores the same value at
  /// the same location, making the earlier write dead.
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      if (checkSameValueWAW(writeOp, defWrite)) {
        // Same value written at the same spot: skip over the earlier write so
        // DCE can remove it once it has no other uses.
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
      // NOTE(review): the condition line is missing from this excerpt —
      // presumably a disjointness check between the two transfers (e.g.
      // isDisjointTransferSets); verify against the upstream file.
        cast<VectorTransferOpInterface>(defWrite.getOperation()),
        cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely look
      // at the previous store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
5839
5840/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5841/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5842/// overwritten and inserted into another tensor. After this rewrite, the
5843/// operations bufferize in-place since all of them work on the same slice.
5844///
5845/// For example:
5846/// ```mlir
5847/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5848/// : vector<8x16xf32>, tensor<8x16xf32>
5849/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5850/// : tensor<8x16xf32> to tensor<?x?xf32>
5851/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5852/// : tensor<?x?xf32> into tensor<27x37xf32>
5853/// ```
5854/// folds to
5855/// ```mlir
5856/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5857/// : tensor<27x37xf32> to tensor<?x?xf32>
5858/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5859/// : vector<8x16xf32>, tensor<?x?xf32>
5860/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5861/// : tensor<?x?xf32> into tensor<27x37xf32>
5862/// ```
5863struct SwapExtractSliceOfTransferWrite
5864 : public OpRewritePattern<tensor::InsertSliceOp> {
5865public:
5866 using Base::Base;
5867
5868 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5869 PatternRewriter &rewriter) const override {
5870 if (!insertOp.hasUnitStride())
5871 return failure();
5872 auto extractOp =
5873 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5874 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5875 return failure();
5876 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5877 if (!transferOp || !transferOp->hasOneUse())
5878 return failure();
5879
5880 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5881 // rank-reducing.
5882 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5883 return rewriter.notifyMatchFailure(insertOp,
5884 "use-def chain is rank-reducing");
5885 }
5886
5887 // Fail if tensor::ExtractSliceOp has non-zero offset.
5888 if (!extractOp.hasZeroOffset()) {
5889 return rewriter.notifyMatchFailure(insertOp,
5890 "ExtractSliceOp has non-zero offset");
5891 }
5892
5893 // Fail if tensor::TransferWriteOp has non-zero offset.
5894 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5895 return getConstantIntValue(value) == static_cast<int64_t>(0);
5896 })) {
5897 return rewriter.notifyMatchFailure(insertOp,
5898 "TranferWriteOp has non-zero offset");
5899 }
5900
5901 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5902 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5903 return rewriter.notifyMatchFailure(
5904 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5905 }
5906
5907 for (auto [insertSize, extractSize] :
5908 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5909 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5910 return rewriter.notifyMatchFailure(
5911 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5912 }
5913 }
5914
5915 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5916 assert(transferOp.getVectorType().hasStaticShape() &&
5917 "expected vector to have a static shape");
5918 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5919 SmallVector<int64_t> resultShape = applyPermutationMap(
5920 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5921 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5922 return rewriter.notifyMatchFailure(
5923 insertOp, "TransferWriteOp may not write the full tensor.");
5924 }
5925
5926 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5927 // Set all in_bounds to false and let the folder infer them.
5928 SmallVector<bool> newInBounds(vectorShape.size(), false);
5929 auto newExtractOp = tensor::ExtractSliceOp::create(
5930 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5931 insertOp.getDest(), insertOp.getMixedOffsets(),
5932 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5933 auto newTransferWriteOp = TransferWriteOp::create(
5934 rewriter, transferOp.getLoc(), transferOp.getVector(),
5935 newExtractOp.getResult(), transferOp.getIndices(),
5936 transferOp.getPermutationMapAttr(),
5937 rewriter.getBoolArrayAttr(newInBounds));
5938 rewriter.modifyOpInPlace(insertOp, [&]() {
5939 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5940 });
5941 return success();
5942 }
5943};
5944
5945} // namespace
5946
5947void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5948 MLIRContext *context) {
5949 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5950}
5951
FailureOr<std::optional<SmallVector<Value>>>
TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
  // Only the memref (buffer) form participates in cast bubbling.
  if (!hasPureBufferSemantics())
    return failure();
  // NOTE(review): the helper call producing the result is truncated in this
  // excerpt — presumably a shared bubble-down implementation invoked with an
  // empty result range; verify against the upstream file.
      ValueRange());
}
5959
5960//===----------------------------------------------------------------------===//
5961// LoadOp
5962//===----------------------------------------------------------------------===//
5963
5964static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5965 VectorType vecTy,
5966 MemRefType memRefTy) {
5967 // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5968 // need any strides limitations.
5969 if (!vecTy.isScalable() &&
5970 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5971 return success();
5972
5973 if (!memRefTy.isLastDimUnitStride())
5974 return op->emitOpError("most minor memref dim must have unit stride");
5975 return success();
5976}
5977
5978LogicalResult vector::LoadOp::verify() {
5979 VectorType resVecTy = getVectorType();
5980 MemRefType memRefTy = getMemRefType();
5981
5982 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5983 return failure();
5984
5985 if (memRefTy.getRank() < resVecTy.getRank())
5986 return emitOpError(
5987 "destination memref has lower rank than the result vector");
5988
5989 // Checks for vector memrefs.
5990 Type memElemTy = memRefTy.getElementType();
5991 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5992 if (memVecTy != resVecTy)
5993 return emitOpError("base memref and result vector types should match");
5994 memElemTy = memVecTy.getElementType();
5995 }
5996
5997 if (resVecTy.getElementType() != memElemTy)
5998 return emitOpError("base and result element types should match");
5999 if (llvm::size(getIndices()) != memRefTy.getRank())
6000 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6001 return success();
6002}
6003
6004OpFoldResult LoadOp::fold(FoldAdaptor) {
6005 if (succeeded(memref::foldMemRefCast(*this)))
6006 return getResult();
6007 return OpFoldResult();
6008}
6009
6010std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6011 return llvm::to_vector<4>(getVectorType().getShape());
6012}
6013
FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the helper call is truncated in this excerpt — presumably a
  // shared bubble-down implementation returning the loaded result; verify
  // against the upstream file.
    getResult());
}
6019
6020//===----------------------------------------------------------------------===//
6021// StoreOp
6022//===----------------------------------------------------------------------===//
6023
6024LogicalResult vector::StoreOp::verify() {
6025 VectorType valueVecTy = getVectorType();
6026 MemRefType memRefTy = getMemRefType();
6027
6028 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
6029 return failure();
6030
6031 if (memRefTy.getRank() < valueVecTy.getRank())
6032 return emitOpError("source memref has lower rank than the vector to store");
6033
6034 // Checks for vector memrefs.
6035 Type memElemTy = memRefTy.getElementType();
6036 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6037 if (memVecTy != valueVecTy)
6038 return emitOpError(
6039 "base memref and valueToStore vector types should match");
6040 memElemTy = memVecTy.getElementType();
6041 }
6042
6043 if (valueVecTy.getElementType() != memElemTy)
6044 return emitOpError("base and valueToStore element type should match");
6045 if (llvm::size(getIndices()) != memRefTy.getRank())
6046 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6047 return success();
6048}
6049
// The only folding performed is absorbing memref.cast on the base operand.
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
6054
6055std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6056 return llvm::to_vector<4>(getVectorType().getShape());
6057}
6058
FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the helper call is truncated in this excerpt — presumably a
  // shared bubble-down implementation invoked with an empty result range;
  // verify against the upstream file.
      ValueRange());
}
6064
6065//===----------------------------------------------------------------------===//
6066// MaskedLoadOp
6067//===----------------------------------------------------------------------===//
6068
6069LogicalResult MaskedLoadOp::verify() {
6070 VectorType maskVType = getMaskVectorType();
6071 VectorType passVType = getPassThruVectorType();
6072 VectorType resVType = getVectorType();
6073 MemRefType memType = getMemRefType();
6074
6075 if (failed(
6076 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6077 return failure();
6078 if (llvm::size(getIndices()) != memType.getRank())
6079 return emitOpError("requires ") << memType.getRank() << " indices";
6080 if (resVType.getShape() != maskVType.getShape())
6081 return emitOpError("expected result shape to match mask shape");
6082 if (resVType != passVType)
6083 return emitOpError("expected pass_thru of same type as result type");
6084 return success();
6085}
6086
6087namespace {
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using Base::Base;
  /// Canonicalize a masked load whose mask is statically known: an all-set
  /// mask becomes a plain vector.load, an all-clear mask folds to the
  /// pass-thru value.
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    // NOTE(review): the case labels are missing from this excerpt —
    // presumably AllTrue / AllFalse / Unknown, matching the replacements
    // below; verify against the upstream file.
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
      rewriter.replaceOp(load, load.getPassThru());
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
6107} // namespace
6108
// Canonicalizes masked loads whose mask is statically all-true/all-false.
void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<MaskedLoadFolder>(context);
}
6113
6114OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6115 if (succeeded(memref::foldMemRefCast(*this)))
6116 return getResult();
6117 return OpFoldResult();
6118}
6119
FailureOr<std::optional<SmallVector<Value>>>
MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the helper call is truncated in this excerpt — presumably a
  // shared bubble-down implementation returning the loaded result; verify
  // against the upstream file.
    getResult());
}
6125
6126//===----------------------------------------------------------------------===//
6127// MaskedStoreOp
6128//===----------------------------------------------------------------------===//
6129
6130LogicalResult MaskedStoreOp::verify() {
6131 VectorType maskVType = getMaskVectorType();
6132 VectorType valueVType = getVectorType();
6133 MemRefType memType = getMemRefType();
6134
6135 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6136 "valueToStore")))
6137 return failure();
6138 if (llvm::size(getIndices()) != memType.getRank())
6139 return emitOpError("requires ") << memType.getRank() << " indices";
6140 if (valueVType.getShape() != maskVType.getShape())
6141 return emitOpError("expected valueToStore shape to match mask shape");
6142 return success();
6143}
6144
6145namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using Base::Base;
  /// Canonicalize a masked store whose mask is statically known: an all-set
  /// mask becomes a plain vector.store, an all-clear mask erases the op.
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    // NOTE(review): the case labels are missing from this excerpt —
    // presumably AllTrue / AllFalse / Unknown, matching the replacements
    // below; verify against the upstream file.
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
      rewriter.eraseOp(store);
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
6165} // namespace
6166
// Canonicalizes masked stores whose mask is statically all-true/all-false.
void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<MaskedStoreFolder>(context);
}
6171
// The only folding performed is absorbing memref.cast on the base operand.
LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
6176
FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the helper call is truncated in this excerpt — presumably a
  // shared bubble-down implementation invoked with an empty result range;
  // verify against the upstream file.
      ValueRange());
}
6182
6183//===----------------------------------------------------------------------===//
6184// GatherOp
6185//===----------------------------------------------------------------------===//
6186
6187LogicalResult GatherOp::verify() {
6188 VectorType indVType = getIndexVectorType();
6189 VectorType maskVType = getMaskVectorType();
6190 VectorType resVType = getVectorType();
6191 ShapedType baseType = getBaseType();
6192
6193 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6194 return emitOpError("requires base to be a memref or ranked tensor type");
6195
6196 if (failed(
6197 verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
6198 return failure();
6199 if (llvm::size(getOffsets()) != baseType.getRank())
6200 return emitOpError("requires ") << baseType.getRank() << " indices";
6201 if (resVType.getShape() != indVType.getShape())
6202 return emitOpError("expected result dim to match indices dim");
6203 if (resVType.getShape() != maskVType.getShape())
6204 return emitOpError("expected result dim to match mask dim");
6205 if (resVType != getPassThruVectorType())
6206 return emitOpError("expected pass_thru of same type as result type");
6207 return success();
6208}
6209
6210// MaskableOpInterface methods.
6211
6212/// Returns the mask type expected by this operation. Mostly used for
6213/// verification purposes. It requires the operation to be vectorized."
6214Type GatherOp::getExpectedMaskType() {
6215 auto vecType = this->getIndexVectorType();
6216 return VectorType::get(vecType.getShape(),
6217 IntegerType::get(vecType.getContext(), /*width=*/1),
6218 vecType.getScalableDims());
6219}
6220
6221std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6222 return llvm::to_vector<4>(getVectorType().getShape());
6223}
6224
6225/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
6226static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6227 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6228 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6229 return failure();
6230
6231 if (indexVec.getDefiningOp<StepOp>())
6232 return success();
6233
6234 DenseIntElementsAttr elements;
6235 if (!matchPattern(indexVec, m_Constant(&elements)))
6236 return failure();
6237
6238 return success(
6239 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6240}
6241
6242namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using Base::Base;
  /// Fold a gather whose mask is statically all-clear to its pass-thru value.
  /// An all-set mask has no unmasked gather equivalent and is left alone.
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    // NOTE(review): the case labels are missing from this excerpt —
    // presumably AllTrue / AllFalse / Unknown, matching the bodies below;
    // verify against the upstream file.
      return failure(); // no unmasked equivalent
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};
6260
6261/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6262/// maskedload. Only 1D fixed vectors are supported for now.
6263class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6264public:
6265 using Base::Base;
6266 LogicalResult matchAndRewrite(GatherOp op,
6267 PatternRewriter &rewriter) const override {
6268 if (!isa<MemRefType>(op.getBase().getType()))
6269 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6270
6271 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6272 return failure();
6273
6274 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6275 op.getOffsets(), op.getMask(),
6276 op.getPassThru());
6277 return success();
6278 }
6279};
6280} // namespace
6281
6282void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6283 MLIRContext *context) {
6284 results.add<GatherFolder, FoldContiguousGather>(context);
6285}
6286
FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the helper call is truncated in this excerpt — presumably a
  // shared bubble-down implementation returning the gathered result; verify
  // against the upstream file.
    getResult());
}
6292
6293//===----------------------------------------------------------------------===//
6294// ScatterOp
6295//===----------------------------------------------------------------------===//
6296
6297LogicalResult ScatterOp::verify() {
6298 VectorType indVType = getIndexVectorType();
6299 VectorType maskVType = getMaskVectorType();
6300 VectorType valueVType = getVectorType();
6301 ShapedType baseType = getBaseType();
6302
6303 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6304 return emitOpError("requires base to be a memref or ranked tensor type");
6305
6306 if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
6307 "valueToStore")))
6308 return failure();
6309 if (llvm::size(getOffsets()) != baseType.getRank())
6310 return emitOpError("requires ") << baseType.getRank() << " indices";
6311 if (valueVType.getShape() != indVType.getShape())
6312 return emitOpError("expected valueToStore dim to match indices dim");
6313 if (valueVType.getShape() != maskVType.getShape())
6314 return emitOpError("expected valueToStore dim to match mask dim");
6315 return success();
6316}
6317namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using Base::Base;
  /// Fold a scatter whose mask is statically all-clear: erase it (memref) or
  /// replace it with the untouched base tensor (tensor).
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    ShapedType baseType = scatter.getBaseType();
    bool isMemRef = isa<MemRefType>(baseType);
    if (!isMemRef && !isa<RankedTensorType>(baseType))
      return failure();

    // Memrefs have no result, so an all-false mask can simply erase the op.
    // Tensors carry the updated value, so we must replace uses with the
    // original base tensor instead of erasing.
    switch (getMaskFormat(scatter.getMask())) {
    // NOTE(review): the case labels are missing from this excerpt —
    // presumably AllTrue / AllFalse / Unknown, matching the bodies below;
    // verify against the upstream file.
      return failure(); // no unmasked equivalent
      if (isMemRef)
        rewriter.eraseOp(scatter);
      else
        rewriter.replaceOp(scatter, scatter.getBase());
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};
6346
6347/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6348/// maskedstore. Only 1D fixed vectors are supported for now.
6349class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6350public:
6351 using Base::Base;
6352 LogicalResult matchAndRewrite(ScatterOp op,
6353 PatternRewriter &rewriter) const override {
6354 // Fold only for memrefs: the replacement uses maskedstore, which does not
6355 // support tensor bases. Tensor cases intentionally bail out.
6356 if (!isa<MemRefType>(op.getBase().getType()))
6357 return failure();
6358
6359 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6360 return failure();
6361
6362 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6363 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6364 return success();
6365 }
6366};
6367} // namespace
6368
6369void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6370 MLIRContext *context) {
6371 results.add<ScatterFolder, FoldContiguousScatter>(context);
6372}
6373
FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the helper call is truncated in this excerpt — presumably a
  // shared bubble-down implementation invoked with an empty result range;
  // verify against the upstream file.
      ValueRange());
}
6379
6380//===----------------------------------------------------------------------===//
6381// ExpandLoadOp
6382//===----------------------------------------------------------------------===//
6383
6384LogicalResult ExpandLoadOp::verify() {
6385 VectorType maskVType = getMaskVectorType();
6386 VectorType passVType = getPassThruVectorType();
6387 VectorType resVType = getVectorType();
6388 MemRefType memType = getMemRefType();
6389
6390 if (failed(
6391 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6392 return failure();
6393 if (llvm::size(getIndices()) != memType.getRank())
6394 return emitOpError("requires ") << memType.getRank() << " indices";
6395 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6396 return emitOpError("expected result dim to match mask dim");
6397 if (resVType != passVType)
6398 return emitOpError("expected pass_thru of same type as result type");
6399 return success();
6400}
6401
6402namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using Base::Base;
  /// Canonicalize an expandload whose mask is statically known: an all-set
  /// mask becomes a plain vector.load, an all-clear mask folds to the
  /// pass-thru value.
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    // NOTE(review): the case labels are missing from this excerpt —
    // presumably AllTrue / AllFalse / Unknown, matching the replacements
    // below; verify against the upstream file.
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
6422} // namespace
6423
// Canonicalizes expandloads whose mask is statically all-true/all-false.
void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
6428
FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the helper call is truncated in this excerpt — presumably a
  // shared bubble-down implementation returning the loaded result; verify
  // against the upstream file.
    getResult());
}
6434
6435//===----------------------------------------------------------------------===//
6436// CompressStoreOp
6437//===----------------------------------------------------------------------===//
6438
6439LogicalResult CompressStoreOp::verify() {
6440 VectorType maskVType = getMaskVectorType();
6441 VectorType valueVType = getVectorType();
6442 MemRefType memType = getMemRefType();
6443
6444 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6445 "valueToStore")))
6446 return failure();
6447 if (llvm::size(getIndices()) != memType.getRank())
6448 return emitOpError("requires ") << memType.getRank() << " indices";
6449 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6450 return emitOpError("expected valueToStore dim to match mask dim");
6451 return success();
6452}
6453
6454namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using Base::Base;
  /// Canonicalize a compressstore whose mask is statically known: an all-set
  /// mask becomes a plain vector.store, an all-clear mask erases the op.
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    // NOTE(review): the case labels are missing from this excerpt —
    // presumably AllTrue / AllFalse / Unknown, matching the replacements
    // below; verify against the upstream file.
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
      rewriter.eraseOp(compress);
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
6475} // namespace
6476
// Canonicalizes compressstores whose mask is statically all-true/all-false.
void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
6481
FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the helper call is truncated in this excerpt — presumably a
  // shared bubble-down implementation invoked with an empty result range;
  // verify against the upstream file.
      ValueRange());
}
6487
6488//===----------------------------------------------------------------------===//
6489// ShapeCastOp
6490//===----------------------------------------------------------------------===//
6491
// A shape cast neither creates nor modifies element values, so the result's
// integer range is exactly the source's.
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
6496
6497std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6498 return llvm::to_vector<4>(getResultVectorType().getShape());
6499}
6500
6501LogicalResult ShapeCastOp::verify() {
6502
6503 VectorType sourceType = getSourceVectorType();
6504 VectorType resultType = getResultVectorType();
6505
6506 // Check that element type is preserved
6507 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
6508 "result")))
6509 return failure();
6510
6511 // Check that number of elements is preserved
6512 int64_t sourceNElms = sourceType.getNumElements();
6513 int64_t resultNElms = resultType.getNumElements();
6514 if (sourceNElms != resultNElms) {
6515 return emitOpError() << "has different number of elements at source ("
6516 << sourceNElms << ") and result (" << resultNElms
6517 << ")";
6518 }
6519
6520 // Check that (non-)scalability is preserved
6521 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6522 int64_t resultNScalableDims = resultType.getNumScalableDims();
6523 if (sourceNScalableDims != resultNScalableDims)
6524 return emitOpError() << "has different number of scalable dims at source ("
6525 << sourceNScalableDims << ") and result ("
6526 << resultNScalableDims << ")";
6527
6528 return success();
6529}
6530
6531/// Return true if `transpose` does not permute a pair of non-unit dims.
6532/// By `order preserving` we mean that the flattened versions of the input and
6533/// output vectors are (numerically) identical. In other words `transpose` is
6534/// effectively a shape cast.
6535static bool isOrderPreserving(TransposeOp transpose) {
6536 ArrayRef<int64_t> permutation = transpose.getPermutation();
6537 VectorType sourceType = transpose.getSourceVectorType();
6538 ArrayRef<int64_t> inShape = sourceType.getShape();
6539 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6540 auto isNonScalableUnitDim = [&](int64_t dim) {
6541 return inShape[dim] == 1 && !inDimIsScalable[dim];
6542 };
6543 int64_t current = 0;
6544 for (auto p : permutation) {
6545 if (!isNonScalableUnitDim(p)) {
6546 if (p < current) {
6547 return false;
6548 }
6549 current = p;
6550 }
6551 }
6552 return true;
6553}
6554
6555OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6556
6557 VectorType resultType = getType();
6558
6559 // No-op shape cast.
6560 if (getSource().getType() == resultType)
6561 return getSource();
6562
6563 // shape_cast(shape_cast(x)) -> shape_cast(x)
6564 if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6565 setOperand(precedingShapeCast.getSource());
6566 return getResult();
6567 }
6568
6569 // shape_cast(transpose(x)) -> shape_cast(x)
6570 if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6571 if (isOrderPreserving(transpose)) {
6572 setOperand(transpose.getVector());
6573 return getResult();
6574 }
6575 return {};
6576 }
6577
6578 // Y = shape_cast(broadcast(X))
6579 // -> X, if X and Y have same type
6580 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6581 if (bcastOp.getSourceType() == resultType)
6582 return bcastOp.getSource();
6583 }
6584
6585 // shape_cast(constant) -> constant
6586 if (auto denseAttr =
6587 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6588 return denseAttr.reshape(getType());
6589
6590 // shape_cast(poison) -> poison
6591 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
6592 return ub::PoisonAttr::get(getContext());
6593
6594 return {};
6595}
6596
6597namespace {
6598
6599/// Helper function that computes a new vector type based on the input vector
6600/// type by removing the trailing one dims:
6601///
6602/// vector<4x1x1xi1> --> vector<4x1xi1>
6603///
6604static VectorType trimTrailingOneDims(VectorType oldType) {
6605 ArrayRef<int64_t> oldShape = oldType.getShape();
6606 ArrayRef<int64_t> newShape = oldShape;
6607
6608 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6609 ArrayRef<bool> newScalableDims = oldScalableDims;
6610
6611 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6612 newShape = newShape.drop_back(1);
6613 newScalableDims = newScalableDims.drop_back(1);
6614 }
6615
6616 // Make sure we have at least 1 dimension.
6617 // TODO: Add support for 0-D vectors.
6618 if (newShape.empty()) {
6619 newShape = oldShape.take_back();
6620 newScalableDims = oldScalableDims.take_back();
6621 }
6622
6623 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6624}
6625
/// Folds qualifying shape_cast(create_mask) into a new create_mask
///
/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
/// dimension. If the input vector comes from `vector.create_mask` for which
/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
/// to fold shape_cast into create_mask.
///
/// BEFORE:
///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
/// AFTER:
///    %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
class ShapeCastCreateMaskFolderTrailingOneDim final
    : public OpRewritePattern<ShapeCastOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
                                PatternRewriter &rewriter) const override {
    Value shapeOpSrc = shapeOp->getOperand(0);
    // The source must be produced by a create_mask or constant_mask op.
    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
    if (!createMaskOp && !constantMaskOp)
      return failure();

    VectorType shapeOpResTy = shapeOp.getResultVectorType();
    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();

    // Only handle shape_casts whose sole effect is dropping the trailing
    // fixed unit dims of the source type.
    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
    if (newVecType != shapeOpResTy)
      return failure();

    auto numDimsToDrop =
        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();

    // No unit dims to drop
    if (!numDimsToDrop)
      return failure();

    if (createMaskOp) {
      auto maskOperands = createMaskOp.getOperands();
      auto numMaskOperands = maskOperands.size();

      // Check every mask dim size to see whether it can be dropped
      // Note: trimTrailingOneDims keeps at least one dim, so (assuming one
      // mask operand per dim, as enforced by CreateMaskOp::verify) the lower
      // bound of this unsigned reverse loop is >= 1 and cannot underflow.
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        // A trailing unit dim is only droppable if its mask bound is the
        // constant 1 (i.e. that dim of the mask is all-true).
        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
        if (!constant || (constant.value() != 1))
          return failure();
      }
      SmallVector<Value> newMaskOperands =
          maskOperands.drop_back(numDimsToDrop);

      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
                                                        newMaskOperands);
      return success();
    }

    if (constantMaskOp) {
      auto maskDimSizes = constantMaskOp.getMaskDimSizes();
      auto numMaskOperands = maskDimSizes.size();

      // Check every mask dim size to see whether it can be dropped
      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
           --i) {
        if (maskDimSizes[i] != 1)
          return failure();
      }

      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
                                                          newMaskOperands);
      return success();
    }

    return failure();
  }
};
6704
6705/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
6706class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6707public:
6708 using Base::Base;
6709
6710 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6711 PatternRewriter &rewriter) const override {
6712 auto broadcastOp =
6713 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6714 if (!broadcastOp)
6715 return failure();
6716
6717 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6718 bool srcIsScalar = !srcVectorType;
6719
6720 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6721 // Example
6722 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6723 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6724 // to
6725 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6726 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6727 if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6728 BroadcastableToResult::Success) {
6729 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6730 shapeCastOp, dstVectorType, broadcastOp.getSource());
6731 return success();
6732 }
6733 return failure();
6734 }
6735};
6736
6737/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X)
6738///
6739/// BEFORE:
6740/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
6741/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
6742/// AFTER:
6743/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
6744///
6745/// Note: this transformation is implemented as an OpRewritePattern, not as a
6746/// fold, because we have to create new op FromElementsOp with updated result
6747/// type. This cannot be done with a fold, because fold cannot create new ops
6748/// and the existing FromElementsOp result type differs from the ShapeCastOp
6749/// result type. Mutating the FromElementsOp (not root op) would violate the
6750/// fold contract and break other users.
6751class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
6752public:
6753 using Base::Base;
6754
6755 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6756 PatternRewriter &rewriter) const override {
6757 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
6758 if (!fromElements)
6759 return failure();
6760
6761 rewriter.replaceOpWithNewOp<FromElementsOp>(
6762 shapeCastOp, shapeCastOp.getResultVectorType(),
6763 fromElements.getElements());
6764 return success();
6765 }
6766};
6767
6768} // namespace
6769
6770void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6771 MLIRContext *context) {
6772 results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
6773 FoldShapeCastOfFromElements>(context);
6774}
6775
6776//===----------------------------------------------------------------------===//
6777// VectorBitCastOp
6778//===----------------------------------------------------------------------===//
6779
6780LogicalResult BitCastOp::verify() {
6781 auto sourceVectorType = getSourceVectorType();
6782 auto resultVectorType = getResultVectorType();
6783
6784 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6785 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6786 return emitOpError("dimension size mismatch at: ") << i;
6787 }
6788
6789 DataLayout dataLayout = DataLayout::closest(*this);
6790 auto sourceElementBits =
6791 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6792 auto resultElementBits =
6793 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6794
6795 if (sourceVectorType.getRank() == 0) {
6796 if (sourceElementBits != resultElementBits)
6797 return emitOpError("source/result bitwidth of the 0-D vector element "
6798 "types must be equal");
6799 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6800 resultElementBits * resultVectorType.getShape().back()) {
6801 return emitOpError(
6802 "source/result bitwidth of the minor 1-D vectors must be equal");
6803 }
6804
6805 return success();
6806}
6807
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    // bitcast(bitcast(x : A -> B) : B -> A) -> x
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Otherwise collapse the chain: rebase this bitcast onto the original
    // source (in-place operand update, then return our own result).
    setOperand(otherOp.getSource());
    return getResult();
  }

  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  // Constant-fold splat constants by recomputing the splat's bit pattern for
  // the result element type.
  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      if (llvm::isa<IntegerType>(dstElemType) && srcElemType.isIntOrFloat()) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower width element.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(), intBits);
        }
      }
    }
  }

  return {};
}
6869
6870//===----------------------------------------------------------------------===//
6871// TypeCastOp
6872//===----------------------------------------------------------------------===//
6873
6874static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6875 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6876 SmallVector<int64_t, 8> res(memRefType.getShape());
6877 if (vectorType)
6878 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6879 return res;
6880}
6881
6882/// Build the canonical memRefType with a single vector.
6883/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6884void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6885 Value source) {
6886 result.addOperands(source);
6887 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6888 VectorType vectorType =
6889 VectorType::get(extractShape(memRefType),
6891 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6892 memRefType.getMemorySpace()));
6893}
6894
6895LogicalResult TypeCastOp::verify() {
6896 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6897 if (!canonicalType.getLayout().isIdentity())
6898 return emitOpError("expects operand to be a memref with identity layout");
6899 if (!getResultMemRefType().getLayout().isIdentity())
6900 return emitOpError("expects result to be a memref with identity layout");
6901 if (getResultMemRefType().getMemorySpace() !=
6902 getMemRefType().getMemorySpace())
6903 return emitOpError("expects result in same memory space");
6904
6905 auto sourceType = getMemRefType();
6906 auto resultType = getResultMemRefType();
6907 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
6909 return emitOpError(
6910 "expects result and operand with same underlying scalar type: ")
6911 << resultType;
6912 if (extractShape(sourceType) != extractShape(resultType))
6913 return emitOpError(
6914 "expects concatenated result and operand shapes to be equal: ")
6915 << resultType;
6916 return success();
6917}
6918
6919//===----------------------------------------------------------------------===//
6920// TransposeOp
6921//===----------------------------------------------------------------------===//
6922
6923void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6924 Value vector, ArrayRef<int64_t> permutation) {
6925 VectorType vt = llvm::cast<VectorType>(vector.getType());
6926 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6927 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6928 for (unsigned i = 0; i < permutation.size(); ++i) {
6929 transposedShape[i] = vt.getShape()[permutation[i]];
6930 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6931 }
6932
6933 result.addOperands(vector);
6934 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6935 transposedScalableDims));
6936 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6937 builder.getDenseI64ArrayAttr(permutation));
6938}
6939
6940OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6941 // Eliminate splat constant transpose ops.
6942 if (auto splat =
6943 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6944 return splat.reshape(getResultVectorType());
6945
6946 // Eliminate poison transpose ops.
6947 if (matchPattern(adaptor.getVector(), ub::m_Poison()))
6948 return ub::PoisonAttr::get(getContext());
6949
6950 // Eliminate identity transposes, and more generally any transposes that
6951 // preserves the shape without permuting elements.
6952 //
6953 // Examples of what to fold:
6954 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6955 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6956 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6957 //
6958 // Example of what NOT to fold:
6959 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6960 //
6961 if (getSourceVectorType() == getResultVectorType() &&
6962 isOrderPreserving(*this))
6963 return getVector();
6964
6965 return {};
6966}
6967
6968LogicalResult vector::TransposeOp::verify() {
6969 VectorType vectorType = getSourceVectorType();
6970 VectorType resultType = getResultVectorType();
6971 int64_t rank = resultType.getRank();
6972 if (vectorType.getRank() != rank)
6973 return emitOpError("vector result rank mismatch: ") << rank;
6974 // Verify transposition array.
6975 ArrayRef<int64_t> perm = getPermutation();
6976 int64_t size = perm.size();
6977 if (rank != size)
6978 return emitOpError("transposition length mismatch: ") << size;
6979 SmallVector<bool, 8> seen(rank, false);
6980 for (const auto &ta : llvm::enumerate(perm)) {
6981 if (ta.value() < 0 || ta.value() >= rank)
6982 return emitOpError("transposition index out of range: ") << ta.value();
6983 if (seen[ta.value()])
6984 return emitOpError("duplicate position index: ") << ta.value();
6985 seen[ta.value()] = true;
6986 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6987 return emitOpError("dimension size mismatch at: ") << ta.value();
6988 }
6989 return success();
6990}
6991
6992std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6993 return llvm::to_vector<4>(getResultVectorType().getShape());
6994}
6995
6996void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6997 SetIntRangeFn setResultRanges) {
6998 setResultRanges(getResult(), argRanges.front());
6999}
7000
7001namespace {
7002
7003// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
7004class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
7005public:
7006 using Base::Base;
7007
7008 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7009 PatternRewriter &rewriter) const override {
7010 // Composes two permutations: result[i] = permutation1[permutation2[i]].
7011 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7012 ArrayRef<int64_t> permutation2) {
7013 SmallVector<int64_t, 4> result;
7014 for (auto index : permutation2)
7015 result.push_back(permutation1[index]);
7016 return result;
7017 };
7018
7019 // Return if the input of 'transposeOp' is not defined by another transpose.
7020 vector::TransposeOp parentTransposeOp =
7021 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7022 if (!parentTransposeOp)
7023 return failure();
7024
7025 SmallVector<int64_t, 4> permutation = composePermutations(
7026 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7027 // Replace 'transposeOp' with a new transpose operation.
7028 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
7029 transposeOp, transposeOp.getResult().getType(),
7030 parentTransposeOp.getVector(), permutation);
7031 return success();
7032 }
7033};
7034
7035/// Replace transpose(splat-like(v)) with broadcast(v)
7036class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
7037public:
7038 using Base::Base;
7039
7040 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7041 PatternRewriter &rewriter) const override {
7042 Value splat = getScalarSplatSource(transposeOp.getVector());
7043 if (!splat)
7044 return failure();
7045
7046 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
7047 transposeOp, transposeOp.getResultVectorType(), splat);
7048 return success();
7049 }
7050};
7051
7052/// Folds transpose(create_mask) into a new transposed create_mask.
7053class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
7054public:
7055 using Base::Base;
7056
7057 LogicalResult matchAndRewrite(TransposeOp transpOp,
7058 PatternRewriter &rewriter) const override {
7059 Value transposeSrc = transpOp.getVector();
7060 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
7061 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
7062 if (!createMaskOp && !constantMaskOp)
7063 return failure();
7064
7065 // Get the transpose permutation and apply it to the vector.create_mask or
7066 // vector.constant_mask operands.
7067 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7068
7069 if (createMaskOp) {
7070 auto maskOperands = createMaskOp.getOperands();
7071 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7072 applyPermutationToVector(newOperands, permutation);
7073
7074 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
7075 transpOp, transpOp.getResultVectorType(), newOperands);
7076 return success();
7077 }
7078
7079 // ConstantMaskOp case.
7080 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7081 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
7082
7083 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
7084 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7085 return success();
7086 }
7087};
7088
7089/// Folds transpose(shape_cast) into a new shape_cast.
7090class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
7091public:
7092 using Base::Base;
7093
7094 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7095 PatternRewriter &rewriter) const override {
7096 auto shapeCastOp =
7097 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7098 if (!shapeCastOp)
7099 return failure();
7100 if (!isOrderPreserving(transposeOp))
7101 return failure();
7102
7103 VectorType resultType = transposeOp.getType();
7104
7105 // We don't need to check isValidShapeCast at this point, because it is
7106 // guaranteed that merging the transpose into the the shape_cast is a valid
7107 // shape_cast, because the transpose just inserts/removes ones.
7108
7109 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
7110 shapeCastOp.getSource());
7111 return success();
7112 }
7113};
7114
/// Folds transpose(from_elements(...)) into a new from_elements with permuted
/// operands matching the transposed shape.
///
/// Example:
///
///   %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
///   vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to
///   vector<3x2xi32>
///
/// becomes ->
///
///   %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
///   vector<3x2xi32>
///
class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElementsOp =
        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
    if (!fromElementsOp)
      return failure();

    VectorType srcTy = fromElementsOp.getDest().getType();
    VectorType dstTy = transposeOp.getType();

    // Transpose semantics: result dim i corresponds to source dim
    // permutation[i].
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
    int64_t rank = srcTy.getRank();

    // Build inverse permutation to map destination indices back to source.
    SmallVector<int64_t> inversePerm(rank, 0);
    for (int64_t i = 0; i < rank; ++i)
      inversePerm[permutation[i]] = i;

    // Row-major strides for converting between linear and multi-dim indices.
    ArrayRef<int64_t> srcShape = srcTy.getShape();
    ArrayRef<int64_t> dstShape = dstTy.getShape();
    SmallVector<int64_t> srcIdx(rank, 0);
    SmallVector<int64_t> dstIdx(rank, 0);
    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
    SmallVector<int64_t> dstStrides = computeStrides(dstShape);

    auto elementsOld = fromElementsOp.getElements();
    SmallVector<Value> elementsNew;
    int64_t dstNumElements = dstTy.getNumElements();
    elementsNew.reserve(dstNumElements);

    // For each element in destination row-major order, pick the corresponding
    // source element.
    for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
      // Pick the destination element index.
      dstIdx = delinearize(linearIdx, dstStrides);
      // Map the destination element index to the source element index.
      for (int64_t j = 0; j < rank; ++j)
        srcIdx[j] = dstIdx[inversePerm[j]];
      // Linearize the source element index.
      int64_t srcLin = linearize(srcIdx, srcStrides);
      // Add the source element to the new elements.
      elementsNew.push_back(elementsOld[srcLin]);
    }

    // Rebuild from_elements directly in the transposed type; the transpose
    // becomes unnecessary.
    rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
                                                elementsNew);
    return success();
  }
};
7181
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
///
/// Example:
/// ```
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
///                                                 to vector<8x1xi32>
/// ```
/// can be rewritten as the equivalent
/// ```
/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
/// ```
/// The algorithm works by partitioning dimensions into groups that can be
/// locally permuted while preserving order, and checks that the transpose
/// only permutes within these groups.
///
/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
/// broadcasting from 1x1x4x1x1x7.
///                            ^^^ ^ ^^^ ^
///                     groups:  0 1   2 3
/// Order preserving permutations for this example are ones that only permute
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
  using Base::Base;
  FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                PatternRewriter &rewriter) const override {

    vector::BroadcastOp broadcast =
        transpose.getVector().getDefiningOp<vector::BroadcastOp>();
    if (!broadcast) {
      return rewriter.notifyMatchFailure(transpose,
                                         "not preceded by a broadcast");
    }

    auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
    VectorType outputType = transpose.getResultVectorType();

    // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
    bool inputIsScalar = !inputType;
    if (inputIsScalar) {
      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                       broadcast.getSource());
      return success();
    }

    ArrayRef<int64_t> permutation = transpose.getPermutation();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    int64_t inputRank = inputType.getRank();
    int64_t outputRank = transpose.getType().getRank();
    // Implicit leading dims added by the broadcast; input dim i lines up with
    // output dim i + deltaRank.
    int64_t deltaRank = outputRank - inputRank;

    // Walk the input dims, closing a group (in output-dim coordinates) each
    // time a group boundary is found, and check the permutation maps that
    // group onto itself.
    int low = 0;
    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
      bool notOne = inputShape[inputIndex] != 1;
      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
      bool groupEndFound = notOne || prevNotOne;
      if (groupEndFound) {
        int high = inputIndex + deltaRank;
        // Return failure if not all permutation destinations for indices in
        // [low, high) are in [low, high), i.e. the permutation is not local to
        // the group.
        for (int i = low; i < high; ++i) {
          if (permutation[i] < low || permutation[i] >= high) {
            return rewriter.notifyMatchFailure(
                transpose, "permutation not local to group");
          }
        }
        low = high;
      }
    }

    // We don't need to check the final group [low, outputRank) because if it is
    // not locally bound, there must be a preceding group that already failed
    // the check (impossible to have just 1 non-locally bound group).

    // The preceding logic also ensures that at this point, the output of the
    // transpose is definitely broadcastable from the input shape, assert so:
    assert(vector::isBroadcastableTo(inputType, outputType) ==
               vector::BroadcastableToResult::Success &&
           "not broadcastable directly to transpose output");

    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
                                                     broadcast.getSource());

    return success();
  }
};
7276
7277} // namespace
7278
7279void vector::TransposeOp::getCanonicalizationPatterns(
7280 RewritePatternSet &results, MLIRContext *context) {
7281 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7282 FoldTransposeSplat, FoldTransposeFromElements,
7283 FoldTransposeBroadcast>(context);
7284}
7285
7286//===----------------------------------------------------------------------===//
7287// ConstantMaskOp
7288//===----------------------------------------------------------------------===//
7289
7290void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7291 VectorType type, ConstantMaskKind kind) {
7292 assert(kind == ConstantMaskKind::AllTrue ||
7293 kind == ConstantMaskKind::AllFalse);
7294 build(builder, result, type,
7295 kind == ConstantMaskKind::AllTrue
7296 ? type.getShape()
7297 : SmallVector<int64_t>(type.getRank(), 0));
7298}
7299
LogicalResult ConstantMaskOp::verify() {
  auto resultType = llvm::cast<VectorType>(getResult().getType());
  // Check the corner case of 0-D vectors first.
  // A 0-D mask is encoded with a single dim size that must be 0 or 1.
  if (resultType.getRank() == 0) {
    if (getMaskDimSizes().size() != 1)
      return emitError("array attr must have length 1 for 0-D vectors");
    auto dim = getMaskDimSizes()[0];
    if (dim != 0 && dim != 1)
      return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
    return success();
  }

  // Verify that array attr size matches the rank of the vector result.
  if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
    return emitOpError(
        "must specify array attr of size equal vector result rank");
  // Verify that each array attr element is in bounds of corresponding vector
  // result dimension size.
  auto resultShape = resultType.getShape();
  auto resultScalableDims = resultType.getScalableDims();
  ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
  for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
    if (maskDimSize < 0 || maskDimSize > resultShape[index])
      return emitOpError(
          "array attr of size out of bounds of vector result dimension size");
    // Scalable dims only support all-false (0) or all-true (the full static
    // size) masks, since the runtime length is unknown.
    if (resultScalableDims[index] && maskDimSize != 0 &&
        maskDimSize != resultShape[index])
      return emitOpError(
          "only supports 'none set' or 'all set' scalable dimensions");
  }
  // Verify that if one mask dim size is zero, they all should be zero (because
  // the mask region is a conjunction of each mask dimension interval).
  bool anyZeros = llvm::is_contained(maskDimSizes, 0);
  bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
  if (anyZeros && !allZeros)
    return emitOpError("expected all mask dim sizes to be zeros, "
                       "as a result of conjunction with zero mask dim");
  return success();
}
7339
7340bool ConstantMaskOp::isAllOnesMask() {
7341 auto resultType = getVectorType();
7342 // Check the corner case of 0-D vectors first.
7343 if (resultType.getRank() == 0) {
7344 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7345 return getMaskDimSizes()[0] == 1;
7346 }
7347 for (const auto [resultSize, maskDimSize] :
7348 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7349 if (maskDimSize < resultSize)
7350 return false;
7351 }
7352 return true;
7353}
7354
7355OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7356 ArrayRef<int64_t> bounds = getMaskDimSizes();
7357 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7358
7359 auto createBoolSplat = [&](bool x) {
7360 return SplatElementsAttr::get(getVectorType(),
7362 };
7363
7364 // Check the corner case of 0-D vectors first.
7365 if (vectorSizes.empty()) {
7366 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7367 return createBoolSplat(bounds[0] == 1);
7368 }
7369 // Fold vector.constant_mask to splat if possible.
7370 if (bounds == vectorSizes)
7371 return createBoolSplat(true);
7372 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7373 return createBoolSplat(false);
7374 return OpFoldResult();
7375}
7376
7377//===----------------------------------------------------------------------===//
7378// CreateMaskOp
7379//===----------------------------------------------------------------------===//
7380
7381void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7382 VectorType type,
7383 ArrayRef<OpFoldResult> mixedOperands) {
7384 SmallVector<Value> operands =
7385 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7386 build(builder, result, type, operands);
7387}
7388
7389LogicalResult CreateMaskOp::verify() {
7390 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7391 // Verify that an operand was specified for each result vector each dimension.
7392 if (vectorType.getRank() == 0) {
7393 if (getNumOperands() != 1)
7394 return emitOpError(
7395 "must specify exactly one operand for 0-D create_mask");
7396 } else if (getNumOperands() !=
7397 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7398 return emitOpError(
7399 "must specify an operand for each result vector dimension");
7400 }
7401 return success();
7402}
7403
7404namespace {
7405
/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
///
/// Ex 1:
///   %c2 = arith.constant 2 : index
///   %c3 = arith.constant 3 : index
///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
/// Becomes:
///    vector.constant_mask [3, 2] : vector<4x3xi1>
///
/// Ex 2:
///   %c_neg_1 = arith.constant -1 : index
///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
/// becomes:
///   vector.constant_mask [0] : vector<[8]xi1>
///
/// Ex 3:
///   %c8 = arith.constant 8 : index
///   %c16 = arith.constant 16 : index
///   %0 = vector.vscale
///   %1 = arith.muli %0, %c16 : index
///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
/// becomes:
///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();
    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();

    // Special case: Rank zero shape.
    // Treat a 0-D mask as a 1-D one of shape [1] so the loop below applies.
    constexpr std::array<int64_t, 1> rankZeroShape{1};
    constexpr std::array<bool, 1> rankZeroScalableDims{false};
    if (maskType.getRank() == 0) {
      maskTypeDimSizes = rankZeroShape;
      maskTypeDimScalableFlags = rankZeroScalableDims;
    }

    // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
    // collect the `constantDims` (for the ConstantMaskOp).
    SmallVector<int64_t, 4> constantDims;
    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
      if (auto intSize = getConstantIntValue(dimSize)) {
        // Constant value.
        // If the mask dim is non-scalable this can be any value.
        // If the mask dim is scalable only zero (all-false) is supported.
        // NOTE(review): this rejects any non-negative constant (including 0)
        // on a scalable dim; only negative constants (clamped to 0 below)
        // fold — confirm whether literal 0 is intentionally excluded.
        if (maskTypeDimScalableFlags[i] && intSize >= 0)
          return failure();
        constantDims.push_back(*intSize);
      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
        // Constant vscale multiple (e.g. 4 x vscale).
        // Must be all-true to fold to a ConstantMask.
        if (vscaleMultiplier < maskTypeDimSizes[i])
          return failure();
        constantDims.push_back(*vscaleMultiplier);
      } else {
        // Non-constant bound: cannot fold.
        return failure();
      }
    }

    // Clamp values to constant_mask bounds.
    for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
      value = std::clamp<int64_t>(value, 0, maskDimSize);

    // If one of dim sizes is zero, set all dims to zero.
    // (The mask is the conjunction of all dim intervals, so one empty
    // interval empties the whole mask.)
    if (llvm::is_contained(constantDims, 0))
      constantDims.assign(constantDims.size(), 0);

    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
                                                constantDims);
    return success();
  }
};
7483
7484} // namespace
7485
/// Registers the canonicalization patterns for vector.create_mask; currently
/// only the fold to vector.constant_mask.
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
7490
7491//===----------------------------------------------------------------------===//
7492// MaskOp
7493//===----------------------------------------------------------------------===//
7494
7495void MaskOp::build(
7496 OpBuilder &builder, OperationState &result, Value mask,
7497 Operation *maskableOp,
7498 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7499 assert(maskRegionBuilder &&
7500 "builder callback for 'maskRegion' must be present");
7501
7502 result.addOperands(mask);
7503 OpBuilder::InsertionGuard guard(builder);
7504 Region *maskRegion = result.addRegion();
7505 builder.createBlock(maskRegion);
7506 maskRegionBuilder(builder, maskableOp);
7507}
7508
/// Builds a vector.mask op with results but no passthru value; delegates to
/// the passthru overload with a null passthru.
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}
7516
/// Builds a vector.mask op with results and an optional passthru value.
/// Operand order matters: the base overload appends the mask first, then the
/// passthru (if any) is appended afterwards.
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Value passthru, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, mask, maskableOp, maskRegionBuilder);
  if (passthru)
    result.addOperands(passthru);
  result.addTypes(resultTypes);
}
7526
/// Parses a vector.mask op of the form:
///   vector.mask %mask [, %passthru] { <masked-op> } [attr-dict]
///     : mask-type [-> result-types]
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse all the operands.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();

  // Optional passthru operand.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse op region.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // Materialize the implicit vector.yield terminator when it was elided in
  // the textual form (also covers an empty region).
  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  if (parsePassthru.succeeded()) {
    // The passthru must match the (single) result type; require a result.
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");

    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }

  return success();
}
7581
7582void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7583 p << " " << getMask();
7584 if (getPassthru())
7585 p << ", " << getPassthru();
7586
7587 // Print single masked operation and skip terminator.
7588 p << " { ";
7589 Block *singleBlock = &getMaskRegion().getBlocks().front();
7590 if (singleBlock && !singleBlock->getOperations().empty())
7591 p.printCustomOrGenericOp(&singleBlock->front());
7592 p << " }";
7593
7594 p.printOptionalAttrDict(getOperation()->getAttrs());
7595
7596 p << " : " << getMask().getType();
7597 if (getNumResults() > 0)
7598 p << " -> " << getResultTypes();
7599}
7600
/// Ensures the mask region ends with a vector.yield terminator, creating one
/// if needed. For a region holding exactly one (masked) operation, the yield
/// forwards that operation's results.
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // 1. For an empty `vector.mask`, create a default terminator.
  if (region.empty() || region.front().empty()) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
  Block &block = region.front();
  if (isa<vector::YieldOp>(block.back()))
    return;

  // 3. For a non-empty `vector.mask` without an explicit terminator:

  // Create default terminator if the number of masked operations is not
  // one. This case will trigger a verification failure.
  // (Intentional: the verifier, not this helper, reports the error.)
  if (block.getOperations().size() != 1) {
    OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
        MaskOp>::ensureTerminator(region, builder, loc);
    return;
  }

  // Create a terminator that yields the results from the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &block.front();
  opBuilder.setInsertionPointToEnd(&block);
  vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
}
7630
/// Verifies a vector.mask op: the region holds at most one masked operation
/// plus a vector.yield terminator, the yielded values/types line up with the
/// op results, the mask type matches the maskable op's expectation, and any
/// passthru is supported and typed like the single result.
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  // At most two ops: the masked operation and the terminator.
  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  // The terminator must yield exactly the masked op's results, in order.
  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
7697
7698/// Folds empty `vector.mask` with no passthru operand and with or without
7699/// return values. For example:
7700///
7701/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7702/// vector<8xi1> -> vector<8xf32>
7703/// %1 = user_op %0 : vector<8xf32>
7704///
7705/// becomes:
7706///
7707/// %0 = user_op %a : vector<8xf32>
7708///
7709/// Empty `vector.mask` with passthru operand are handled by the canonicalizer
7710/// as it requires creating new operations.
7711
7712static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7713 SmallVectorImpl<OpFoldResult> &results) {
7714 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7715 return failure();
7716
7717 Block *block = maskOp.getMaskBlock();
7718 auto terminator = cast<vector::YieldOp>(block->front());
7719 if (terminator.getNumOperands() == 0)
7720 return failure();
7721
7722 // `vector.mask` has results, propagate the results.
7723 llvm::append_range(results, terminator.getOperands());
7724 return success();
7725}
7726
/// Folds vector.mask ops: empty masks forward their yielded values; all-true
/// masks are removed by hoisting the masked operation out of the region.
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();

  // Only an all-true mask makes the masked op equivalent to its unmasked form.
  MaskFormat maskFormat = getMaskFormat(getMask());
  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move maskable operation outside of the `vector.mask` region.
  // If there is no maskable op (empty body), the fold cannot proceed; the
  // canonicalizer handles this case instead.
  Operation *maskableOp = getMaskableOp();
  if (!maskableOp)
    return failure();
  // Drop the terminator's uses of the masked op's results before moving it;
  // the fold results below re-establish the uses outside the region.
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
7748
7749/// Canonialize empty `vector.mask` operations that can't be handled in
7750/// `VectorMask::fold` as they require creating new operations.
7751///
7752/// Example 1: Empty `vector.mask` with passthru operand.
7753///
7754/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7755/// vector<8xi1> -> vector<8xf32>
7756///
7757/// becomes:
7758///
7759/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7760///
7761class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7762 using Base::Base;
7763
7764 LogicalResult matchAndRewrite(MaskOp maskOp,
7765 PatternRewriter &rewriter) const override {
7766 if (!maskOp.isEmpty())
7767 return failure();
7768
7769 if (!maskOp.hasPassthru())
7770 return failure();
7771
7772 // arith.select with a vector condition requires the value types to be
7773 // vectors of the same shape. Since vector.mask always has a vector mask
7774 // type, bail out when any result type doesn't match the mask shape to
7775 // avoid creating invalid IR.
7776 VectorType maskType = maskOp.getMask().getType();
7777 for (Type resultType : maskOp.getResultTypes()) {
7778 auto vecResultType = dyn_cast<VectorType>(resultType);
7779 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
7780 return failure();
7781 }
7782
7783 Block *block = maskOp.getMaskBlock();
7784 auto terminator = cast<vector::YieldOp>(block->front());
7785 assert(terminator.getNumOperands() == 1 &&
7786 "expected one result when passthru is provided");
7787
7788 rewriter.replaceOpWithNewOp<arith::SelectOp>(
7789 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7790 terminator.getOperand(0), maskOp.getPassthru());
7791
7792 return success();
7793 }
7794};
7795
/// Registers the canonicalization patterns for vector.mask; currently only
/// the empty-mask-with-passthru rewrite to arith.select.
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
7800
7801// MaskingOpInterface definitions.
7802
7803/// Returns the operation masked by this 'vector.mask'.
7804Operation *MaskOp::getMaskableOp() {
7805 Block *block = getMaskBlock();
7806 if (block->getOperations().size() < 2)
7807 return nullptr;
7808
7809 return &block->front();
7810}
7811
7812/// Returns true if 'vector.mask' has a passthru value.
7813bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7814
7815//===----------------------------------------------------------------------===//
7816// ScanOp
7817//===----------------------------------------------------------------------===//
7818
7819LogicalResult ScanOp::verify() {
7820 VectorType srcType = getSourceType();
7821 VectorType initialType = getInitialValueType();
7822 // Check reduction dimension < rank.
7823 int64_t srcRank = srcType.getRank();
7824 int64_t reductionDim = getReductionDim();
7825 if (reductionDim >= srcRank)
7826 return emitOpError("reduction dimension ")
7827 << reductionDim << " has to be less than " << srcRank;
7828
7829 // Check that rank(initial_value) = rank(src) - 1.
7830 int64_t initialValueRank = initialType.getRank();
7831 if (initialValueRank != srcRank - 1)
7832 return emitOpError("initial value rank ")
7833 << initialValueRank << " has to be equal to " << srcRank - 1;
7834
7835 // Check shapes of initial value and src.
7836 ArrayRef<int64_t> srcShape = srcType.getShape();
7837 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7838 SmallVector<int64_t> expectedShape;
7839 for (int i = 0; i < srcRank; i++) {
7840 if (i != reductionDim)
7841 expectedShape.push_back(srcShape[i]);
7842 }
7843 if (!llvm::equal(initialValueShapes, expectedShape)) {
7844 return emitOpError("incompatible input/initial value shapes");
7845 }
7846
7847 // Verify supported reduction kind.
7848 Type eltType = getDestType().getElementType();
7849 if (!isSupportedCombiningKind(getKind(), eltType))
7850 return emitOpError("unsupported reduction type ")
7851 << eltType << " for kind '" << stringifyCombiningKind(getKind())
7852 << "'";
7853
7854 return success();
7855}
7856
7858 RewritePatternSet &patterns, PatternBenefit benefit) {
7859 patterns
7860 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7861 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7862 StridedSliceConstantMaskFolder, TransposeFolder>(
7863 patterns.getContext(), benefit);
7864}
7865
/// Creates the arith operation combining `v1` into `acc` for the given
/// combining kind, applying `fastmath` to floating-point ops. When `mask` is
/// non-null the accumulator is selected for masked-off lanes (see
/// selectPassthru). Integer/index operands map to the integer arith ops and
/// float operands to the float ops; a mismatch asserts/is unreachable.
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
                                       CombiningKind kind, Value v1, Value acc,
                                       arith::FastMathFlagsAttr fastmath,
                                       Value mask) {
  Type t1 = getElementTypeOrSelf(v1.getType());
  Type tAcc = getElementTypeOrSelf(acc.getType());
  Value result;

  switch (kind) {
  // ADD/MUL dispatch on the element type: integer vs float forms.
  case CombiningKind::ADD:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for ADD reduction");
    break;
  // Bitwise kinds are integer-only.
  case CombiningKind::AND:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
    break;
  // Float min/max kinds (NaN-propagating *IMUMF vs NaN-ignoring *NUMF).
  case CombiningKind::MAXNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MAXIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINNUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
    break;
  case CombiningKind::MINIMUMF:
    assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
           "expected float values");
    result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
    break;
  // Integer min/max kinds (signed and unsigned variants).
  case CombiningKind::MAXSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINSI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
    break;
  case CombiningKind::MAXUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MINUI:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
    break;
  case CombiningKind::MUL:
    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
    else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
      result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
    else
      llvm_unreachable("invalid value types for MUL reduction");
    break;
  case CombiningKind::OR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
    break;
  case CombiningKind::XOR:
    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
    break;
  };

  assert(result && "unknown CombiningKind");
  // Propagate the accumulator for masked-off lanes (no-op when mask is null).
  return selectPassthru(b, mask, result, acc);
}
7944
7945//===----------------------------------------------------------------------===//
7946// StepOp
7947//===----------------------------------------------------------------------===//
7948
7949void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7950 SetIntRangeFn setResultRanges) {
7951 auto resultType = cast<VectorType>(getType());
7952 if (resultType.isScalable()) {
7953 return;
7954 }
7955 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7956 APInt zero(bitwidth, 0);
7957 APInt high(bitwidth, resultType.getDimSize(0) - 1);
7958 ConstantIntRanges result = {zero, high, zero, high};
7959 setResultRanges(getResult(), result);
7960}
7961
namespace {

/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
/// constant large enough such that the result is the same at all indices.
///
/// For example, rewrite the 'greater than' comparison below,
///
/// ```mlir
/// %cst = arith.constant dense<7> : vector<3xindex>
/// %stp = vector.step : vector<3xindex>
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
/// ```
///
/// as,
///
/// ```mlir
/// %out = arith.constant dense<false> : vector<3xi1>.
/// ```
///
/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
/// is false at ALL indices we fold. If the constant was 1, then
/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
/// conservatively preferring the 'compact' vector.step representation.
///
/// Note: this folder only works for the case where the constant (`%cst` above)
/// is the second operand of the comparison. The arith.cmpi canonicalizer will
/// ensure that constants are always second (on the right).
struct StepCompareFolder : public OpRewritePattern<StepOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    // Number of lanes; the step values are exactly [0, stepSize - 1].
    const int64_t stepSize = stepOp.getResult().getType().getNumElements();

    // Scan all cmpi users of the step; fold the first one that qualifies.
    for (OpOperand &use : stepOp.getResult().getUses()) {
      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
      if (!cmpiOp)
        continue;

      // arith.cmpi canonicalizer makes constants final operands.
      const unsigned stepOperandNumber = use.getOperandNumber();
      if (stepOperandNumber != 0)
        continue;

      // Check that operand 1 is a constant.
      unsigned constOperandNumber = 1;
      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
      std::optional<int64_t> maybeConstValue =
          getConstantIntValue(otherOperand);
      if (!maybeConstValue.has_value())
        continue;

      int64_t constValue = maybeConstValue.value();
      arith::CmpIPredicate pred = cmpiOp.getPredicate();

      // Decide whether the comparison has a uniform answer for every lane.
      // Returns the splat value when it does, std::nullopt otherwise.
      auto maybeSplat = [&]() -> std::optional<bool> {
        // Handle ult (unsigned less than) and uge (unsigned greater equal).
        // All lanes < constValue iff stepSize <= constValue.
        if ((pred == arith::CmpIPredicate::ult ||
             pred == arith::CmpIPredicate::uge) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ult;

        // Handle ule and ugt. Max lane value is stepSize - 1.
        if ((pred == arith::CmpIPredicate::ule ||
             pred == arith::CmpIPredicate::ugt) &&
            stepSize - 1 <= constValue) {
          return pred == arith::CmpIPredicate::ule;
        }

        // Handle eq and ne. No lane equals constValue when it is >= stepSize.
        if ((pred == arith::CmpIPredicate::eq ||
             pred == arith::CmpIPredicate::ne) &&
            stepSize <= constValue)
          return pred == arith::CmpIPredicate::ne;

        return std::nullopt;
      }();

      if (!maybeSplat.has_value())
        continue;

      rewriter.setInsertionPointAfter(cmpiOp);

      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
      if (!type)
        continue;

      // Replace the comparison with a constant splat of the uniform answer.
      auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
      Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
                                                    type, boolAttr);

      rewriter.replaceOp(cmpiOp, splat);
      return success();
    }

    return failure();
  }
};
} // namespace
8061
/// Registers the canonicalization patterns for vector.step; currently only
/// the step-vs-constant comparison folder.
void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<StepCompareFolder>(context);
}
8066
8067//===----------------------------------------------------------------------===//
8068// Vector Masking Utilities
8069//===----------------------------------------------------------------------===//
8070
/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
/// as masked operation.
void mlir::vector::createMaskOpRegion(OpBuilder &builder,
                                      Operation *maskableOp) {
  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
  Block *insBlock = builder.getInsertionBlock();
  // Create a block and move the op to that block.
  // (The splice transfers `maskableOp` from its current block into the mask
  // region's block without cloning it.)
  insBlock->getOperations().splice(
      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
  // Terminate the region by yielding the masked op's results.
  YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
}
8082
8083/// Creates a vector.mask operation around a maskable operation. Returns the
8084/// vector.mask operation if the mask provided is valid. Otherwise, returns
8085/// the maskable operation itself.
8086Operation *mlir::vector::maskOperation(OpBuilder &builder,
8087 Operation *maskableOp, Value mask,
8088 Value passthru) {
8089 if (!mask)
8090 return maskableOp;
8091 if (passthru)
8092 return MaskOp::create(builder, maskableOp->getLoc(),
8093 maskableOp->getResultTypes(), mask, passthru,
8094 maskableOp, createMaskOpRegion);
8095 return MaskOp::create(builder, maskableOp->getLoc(),
8096 maskableOp->getResultTypes(), mask, maskableOp,
8098}
8099
/// Creates a vector select operation that picks values from `newValue` or
/// `passthru` for each result vector lane based on `mask`. This utility is used
/// to propagate the pass-thru value of vector.mask or for cases where only the
/// pass-thru value propagation is needed. VP intrinsics do not support
/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
/// usually able to match op + select patterns and fold them into a native
/// target instructions.
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
                                   Value newValue, Value passthru) {
  // No mask: nothing to select against; forward the computed value.
  if (!mask)
    return newValue;

  return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
                                 mask, newValue, passthru);
}
8115
8116//===----------------------------------------------------------------------===//
8117// TableGen'd op method definitions
8118//===----------------------------------------------------------------------===//
8119
8120#define GET_ATTRDEF_CLASSES
8121#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8122
8123#define GET_OP_CLASSES
8124#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
b getContext())
auto load
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition VectorOps.cpp:72
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
Definition VectorOps.cpp:62
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
#define mul(a, b)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
bool empty()
Definition Block.h:158
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
iterator begin()
Definition Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:379
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:863
void setOperand(unsigned idx, Value value)
Definition Operation.h:380
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
operand_type_range getOperandTypes()
Definition Operation.h:426
result_type_range getResultTypes()
Definition Operation.h:457
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:444
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:433
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isF32() const
Definition Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition VectorOps.h:64
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition VectorOps.h:72
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a pattern such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const