MLIR 23.0.0git
VectorOps.cpp
Go to the documentation of this file.
1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
15
29#include "mlir/IR/AffineExpr.h"
30#include "mlir/IR/AffineMap.h"
31#include "mlir/IR/Builders.h"
35#include "mlir/IR/IRMapping.h"
39#include "mlir/IR/ValueRange.h"
42#include "mlir/Support/LLVM.h"
44#include "llvm/ADT/ArrayRef.h"
45#include "llvm/ADT/Repeated.h"
46#include "llvm/ADT/STLExtras.h"
47#include "llvm/ADT/SmallVector.h"
48#include "llvm/ADT/SmallVectorExtras.h"
49#include "llvm/ADT/StringSet.h"
50#include "llvm/ADT/TypeSwitch.h"
51#include "llvm/Support/Casting.h"
52
53#include <cassert>
54#include <cstdint>
55#include <numeric>
56
57#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
58// Pull in all enum type and utility function definitions.
59#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
60
61using namespace mlir;
62using namespace mlir::vector;
63
64/// Helper enum to classify mask value.
// NOTE(review): the embedded listing numbers jump from 65 to 69, so the
// enumerator lines appear to be missing from this copy. Code further down
// refers to MaskFormat::AllTrue and MaskFormat::Unknown.
65enum class MaskFormat {
69};
70
71/// Helper method to classify a mask value. Currently, the method
72/// looks "under the hood" of a constant value with dense attributes
73/// and a constant mask operation (since the client may be called at
74/// various stages during progressive lowering).
// NOTE(review): the embedded listing numbers skip 75, 88, 90, 92, 111 and
// 122. The function signature and several statements (presumably `return`s
// following the visible conditions) are missing from this copy and must be
// restored from the original file before this compiles.
76 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
77 // Inspect constant dense values. We count up for bits that
78 // are set, count down for bits that are cleared, and bail
79 // when a mix is detected.
80 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
81 int64_t val = 0;
82 for (bool b : denseElts.getValues<bool>())
83 if (b && val >= 0)
84 val++;
85 else if (!b && val <= 0)
86 val--;
87 else
89 if (val > 0)
91 if (val < 0)
93 }
94 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
95 // Inspect constant mask index. If the index exceeds the
96 // dimension size, all bits are set. If the index is zero
97 // or less, no bits are set.
98 ArrayRef<int64_t> masks = m.getMaskDimSizes();
99 auto shape = m.getType().getShape();
100 bool allTrue = true;
101 bool allFalse = true;
102 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
103 if (maskIdx < dimSize)
104 allTrue = false;
105 if (maskIdx > 0)
106 allFalse = false;
107 }
108 if (allTrue)
109 return MaskFormat::AllTrue;
110 if (allFalse)
112 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
113 // Finds all-false create_masks. An all-true create_mask requires all
114 // dims to be constants, so that'll be folded to a constant_mask, then
115 // detected in the constant_mask case.
116 auto maskOperands = m.getOperands();
117 for (Value operand : maskOperands) {
118 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
119 int64_t dimSize =
120 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
121 if (dimSize <= 0)
123 }
124 }
125 return MaskFormat::Unknown;
126 }
127 return MaskFormat::Unknown;
128}
129
130/// Default callback to build a region with a 'vector.yield' terminator with no
131/// arguments.
// NOTE(review): listing number 132 (the function signature) is missing from
// this copy; the body uses `builder` and `loc`, which it must declare.
133 vector::YieldOp::create(builder, loc);
134}
135
136// Helper for verifying combining kinds in contractions and reductions.
137static bool isSupportedCombiningKind(CombiningKind combiningKind,
138 Type elementType) {
139 switch (combiningKind) {
140 case CombiningKind::ADD:
141 case CombiningKind::MUL:
142 return elementType.isIntOrIndexOrFloat();
143 case CombiningKind::MINUI:
144 case CombiningKind::MINSI:
145 case CombiningKind::MAXUI:
146 case CombiningKind::MAXSI:
147 case CombiningKind::AND:
148 case CombiningKind::OR:
149 case CombiningKind::XOR:
150 return elementType.isIntOrIndex();
151 case CombiningKind::MINNUMF:
152 case CombiningKind::MAXNUMF:
153 case CombiningKind::MINIMUMF:
154 case CombiningKind::MAXIMUMF:
155 return llvm::isa<FloatType>(elementType);
156 }
157 return false;
158}
159
160/// Returns the effective rank of the vector to read/write for Xfer Ops
161///
162/// When the element type of the shaped type is _a scalar_, this will simply
163/// return the rank of the vector ( the result for xfer_read or the value to
164/// store for xfer_write).
165///
166/// When the element type of the base shaped type is _a vector_, returns the
167/// difference between the original vector type and the element type of the
168/// shaped type.
169///
170/// EXAMPLE 1 (element type is _a scalar_):
171/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
172/// - shapedType.getElementType() = f32 (rank 0)
173/// - vectorType.getRank() = 2
174/// - Result = 2 - 0 = 2
175///
176/// EXAMPLE 2 (element type is _a vector_):
177/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
178/// - shapedType.getElementType() = vector<20xf32> (rank 1)
179/// - vectorType.getRank() = 1
180/// - Result = 1 - 1 = 0
181///
182/// This is used to determine the number of minor dimensions for identity maps
183/// in vector transfer Ops.
184static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
185 VectorType vectorType) {
186 unsigned elementVectorRank = 0;
187 VectorType elementVectorType =
188 llvm::dyn_cast<VectorType>(shapedType.getElementType());
189 if (elementVectorType)
190 elementVectorRank += elementVectorType.getRank();
191 return vectorType.getRank() - elementVectorRank;
192}
193
// NOTE(review): listing numbers 194 (first line of the signature) and 203
// (start of the final `return` expression) are missing from this copy.
// What is visible: a rank-0 shaped type paired with a vector of shape {1}
// yields a 0-dim, 0-symbol map with the constant-0 result; otherwise the
// result is built from the shaped rank and the effective vector rank.
195 VectorType vectorType) {
196 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
197 // TODO: replace once we have 0-d vectors.
198 if (shapedType.getRank() == 0 &&
199 vectorType.getShape() == ArrayRef<int64_t>{1})
200 return AffineMap::get(
201 /*numDims=*/0, /*numSymbols=*/0,
202 getAffineConstantExpr(0, shapedType.getContext()));
204 shapedType.getRank(),
205 getEffectiveVectorRankForXferOp(shapedType, vectorType),
206 shapedType.getContext());
207}
208
209/// Check if `write` is of a constant splat and the masked `read` is padded with
210/// the same splat value -- meaning it could be the same value as the initial
211/// constant splat.
212static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
213 vector::TransferReadOp read) {
214 auto readMask = read.getMask();
215 auto writeMask = write.getMask();
216 // Check if the masks are consistent. The splat value could be the same if the
217 // read is masked (and padded with the splat value), and the write is unmasked
218 // or has the same mask. Note this does not allow the case where the write is
219 // masked and the read is unmasked, as then the read could be of more elements
220 // than the write (which may not be the same value).
221 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
222 if (!couldBeSameSplat)
223 return false;
224 // Check for constant splat (as the source of the write).
225 DenseElementsAttr splatAttr;
226 if (!matchPattern(write.getVector(),
227 m_Constant<DenseElementsAttr>(&splatAttr)) ||
228 !splatAttr.isSplat()) {
229 return false;
230 }
231 // The padding of the read and the constant splat value must be the same.
232 Attribute padAttr;
233 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
234 return false;
235 return padAttr == splatAttr.getSplatValue<Attribute>();
236}
237
// Read-after-write check between `defWrite` and `read`. The visible
// conditions require: no out-of-bounds dim on the write, and identical
// indices, vector type and permutation map on both ops; the mask condition
// accepts the case where neither op is masked.
// NOTE(review): listing number 245 (the second alternative of the mask
// condition and the closing of the expression) is missing from this copy.
238bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
239 vector::TransferReadOp read) {
240 return !defWrite.hasOutOfBoundsDim() &&
241 defWrite.getIndices() == read.getIndices() &&
242 defWrite.getVectorType() == read.getVectorType() &&
243 defWrite.getPermutationMap() == read.getPermutationMap() &&
244 ((!defWrite.getMask() && !read.getMask()) ||
246}
247
248bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
249 vector::TransferWriteOp priorWrite) {
250 return priorWrite.getIndices() == write.getIndices() &&
251 priorWrite.getMask() == write.getMask() &&
252 priorWrite.getVectorType() == write.getVectorType() &&
253 priorWrite.getPermutationMap() == write.getPermutationMap();
254}
255
// Returns true only when the two transfers can be proven to access disjoint
// index ranges; returns false whenever that cannot be shown.
// NOTE(review): listing numbers 256 (start of the signature), 281, 286, 304
// and 309 (the initializers of `delta`, `testEqual`, `delta` and
// `computeDelta`) are missing from this copy.
257 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
258 bool testDynamicValueUsingBounds) {
259 // For simplicity only look at transfer of same type.
260 if (transferA.getVectorType() != transferB.getVectorType())
261 return false;
262 unsigned rankOffset = transferA.getLeadingShapedRank();
263 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
264 Value indexA = transferA.getIndices()[i];
265 Value indexB = transferB.getIndices()[i];
266 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
267 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
268
269 if (i < rankOffset) {
270 // For leading dimensions, if we can prove that the indices are different
271 // we know we are accessing disjoint slices.
272 if (cstIndexA.has_value() && cstIndexB.has_value()) {
273 if (*cstIndexA != *cstIndexB)
274 return true;
275 continue;
276 }
277 if (testDynamicValueUsingBounds) {
278 // First try to see if we can fully compose and simplify the affine
279 // expression as a fast track.
280 FailureOr<uint64_t> delta =
282 if (succeeded(delta) && *delta != 0)
283 return true;
284
285 FailureOr<bool> testEqual =
287 if (succeeded(testEqual) && !testEqual.value())
288 return true;
289 }
290 } else {
291 // For this dimension, we slice a part of the memref, so we need to make
292 // sure the intervals accessed don't overlap.
293 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
294 if (cstIndexA.has_value() && cstIndexB.has_value()) {
295 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
296 if (distance >= vectorDim)
297 return true;
298 continue;
299 }
300 if (testDynamicValueUsingBounds) {
301 // First try to see if we can fully compose and simplify the affine
302 // expression as a fast track.
303 FailureOr<int64_t> delta =
305 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
306 return true;
307
308 FailureOr<int64_t> computeDelta =
310 if (succeeded(computeDelta)) {
311 if (std::abs(computeDelta.value()) >= vectorDim)
312 return true;
313 }
314 }
315 }
316 }
317 return false;
318}
319
320bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
321 VectorTransferOpInterface transferB,
322 bool testDynamicValueUsingBounds) {
323 if (transferA.getBase() != transferB.getBase())
324 return false;
325 return isDisjointTransferIndices(transferA, transferB,
326 testDynamicValueUsingBounds);
327}
328
329// Helper to iterate over n-D vector slice elements. Calculate the next
330// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
331// Modifies the `position` in place. Returns a failure when `position` becomes
332// the end position.
// NOTE(review): listing number 334 (the `shape` parameter declaration) is
// missing from this copy.
// The loop walks dimensions from innermost to outermost. A dimension that
// stays below `dimSize + offsetInDim` ends the walk; otherwise it is reset to
// its offset and the next outer dimension is incremented.
333static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
335 ArrayRef<int64_t> offsets) {
336 for (auto [posInDim, dimSize, offsetInDim] :
337 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
338 ++posInDim;
339 if (posInDim < dimSize + offsetInDim)
340 return success();
341
342 // Carry the overflow to the next loop iteration.
343 posInDim = offsetInDim;
344 }
345
346 return failure();
347}
348
349/// Returns the integer numbers in `values`. `values` are expected to be
350/// constant operations.
// NOTE(review): listing numbers 351-352 (the signature and the declaration of
// `ints`) are missing from this copy.
// Each value must be defined by an arith::ConstantIndexOp; this is enforced
// only by the assert below.
353 llvm::transform(values, std::back_inserter(ints), [](Value value) {
354 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
355 assert(constOp && "Unexpected non-constant index");
356 return constOp.value();
357 });
358 return ints;
359}
360
361/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
362/// be constant operations.
// NOTE(review): listing numbers 363-364 (the signature and the declaration of
// `ints`) are missing from this copy.
// Each fold result must hold an Attribute (checked by the assert) that is an
// IntegerAttr (checked by the cast).
365 llvm::transform(
366 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
367 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
368 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
369 });
370 return ints;
371}
372
373/// Convert `foldResults` into Values. Integer attributes are converted to
374/// constant op.
// NOTE(review): listing numbers 375 (start of the signature) and 381 (the
// start of the `return ...create(` expression in the attribute branch) are
// missing from this copy.
376 ArrayRef<OpFoldResult> foldResults) {
377 SmallVector<Value> values;
378 llvm::transform(foldResults, std::back_inserter(values),
379 [&](OpFoldResult foldResult) {
380 if (auto attr = dyn_cast<Attribute>(foldResult))
382 builder, loc, cast<IntegerAttr>(attr).getInt())
383 .getResult();
384
385 return cast<Value>(foldResult);
386 });
387 return values;
388}
389
390std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
391 if (value.getDefiningOp<vector::VectorScaleOp>())
392 return 1;
393 auto mul = value.getDefiningOp<arith::MulIOp>();
394 if (!mul)
395 return {};
396 auto lhs = mul.getLhs();
397 auto rhs = mul.getRhs();
398 if (lhs.getDefiningOp<vector::VectorScaleOp>())
399 return getConstantIntValue(rhs);
400 if (rhs.getDefiningOp<vector::VectorScaleOp>())
401 return getConstantIntValue(lhs);
402 return {};
403}
404
405/// Converts numeric attributes to the expected type. Supports
406/// integer-to-integer and float-to-integer conversions. Returns the original
407/// attribute if no conversion is needed or supported.
408static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
409 // Integer-to-integer conversion
410 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
411 if (auto intType = dyn_cast<IntegerType>(expectedType)) {
412 if (intAttr.getType() != expectedType)
413 return IntegerAttr::get(expectedType, intAttr.getInt());
414 }
415 return attr;
416 }
417
418 // Float-to-integer bitcast (preserves bit representation)
419 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
420 auto intType = dyn_cast<IntegerType>(expectedType);
421 if (!intType)
422 return attr;
423
424 APFloat floatVal = floatAttr.getValue();
425 APInt intVal = floatVal.bitcastToAPInt();
426 return IntegerAttr::get(expectedType, intVal);
427 }
428
429 return attr;
430}
431
432//===----------------------------------------------------------------------===//
433// CombiningKindAttr
434//===----------------------------------------------------------------------===//
435
436namespace mlir {
437namespace vector {
438namespace detail {
// NOTE(review): listing numbers 439, 442, 446, 449 and 452 are missing from
// this copy: the struct header, what is presumably its constructor, the
// start of the allocator-based factory signature, the constructor call
// inside it, and the member declaration. The visible code shows a storage
// keyed on a uint64_t, compared against its `value` member and
// placement-constructed in allocator-provided memory.
440 using KeyTy = uint64_t;
441
443
444 bool operator==(const KeyTy &key) const { return value == key; }
445
447 const KeyTy &key) {
448 return new (allocator.allocate<BitmaskEnumStorage>())
450 }
451
453};
454} // namespace detail
455} // namespace vector
456} // namespace mlir
457
458//===----------------------------------------------------------------------===//
459// VectorDialect
460//===----------------------------------------------------------------------===//
461
462namespace {
463/// This class defines the interface for handling inlining with vector dialect
464/// operations.
465struct VectorInlinerInterface : public DialectInlinerInterface {
466 using DialectInlinerInterface::DialectInlinerInterface;
467
468 /// All vector dialect ops can be inlined.
469 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
470 return true;
471 }
472};
473} // namespace
474
// Registers the dialect's attributes and operations (both lists come from the
// generated .cpp.inc files), adds the inliner interface, and declares the
// promised interfaces listed below.
475void VectorDialect::initialize() {
476 addAttributes<
477#define GET_ATTRDEF_LIST
478#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
479 >();
480
481 addOperations<
482#define GET_OP_LIST
483#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
484 >();
485
486 addInterfaces<VectorInlinerInterface>();
487
488 declarePromisedInterfaces<memref::IndexedAccessOpInterface, LoadOp, StoreOp,
489 MaskedLoadOp, MaskedStoreOp, ExpandLoadOp,
490 CompressStoreOp>();
491 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
492 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
493 YieldOp>();
494 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
495 TransferWriteOp>();
496 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
497 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
498 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
499}
500
501/// Materialize a single constant operation from a given attribute value with
502/// the desired resultant type.
503Operation *VectorDialect::materializeConstant(OpBuilder &builder,
504 Attribute value, Type type,
505 Location loc) {
506 if (matchPattern(value, ub::m_Poison()))
507 return value.getDialect().materializeConstant(builder, value, type, loc);
508
509 return arith::ConstantOp::materialize(builder, value, type, loc);
510}
511
// NOTE(review): listing number 512 (the function signature) is missing from
// this copy. The body returns the builder's 64-bit integer type.
513 return builder.getIntegerType(64);
514}
515
// NOTE(review): listing number 516 (the first line of the signature) is
// missing from this copy. The body wraps `values` in an I64 array attribute.
517 ArrayRef<int64_t> values) {
518 return builder.getI64ArrayAttr(values);
519}
520
521//===----------------------------------------------------------------------===//
522// MultiDimReductionOp
523//===----------------------------------------------------------------------===//
524
525void vector::MultiDimReductionOp::build(OpBuilder &builder,
526 OperationState &result, Value source,
527 Value acc, ArrayRef<bool> reductionMask,
528 CombiningKind kind) {
529 SmallVector<int64_t> reductionDims;
530 for (const auto &en : llvm::enumerate(reductionMask))
531 if (en.value())
532 reductionDims.push_back(en.index());
533 build(builder, result, kind, source, acc, reductionDims);
534}
535
536OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
537 // No reduction dims: this is a noop regardless of rank.
538 if (getReductionDims().empty())
539 return getSource();
540 return {};
541}
542
543std::optional<SmallVector<int64_t, 4>>
544MultiDimReductionOp::getShapeForUnroll() {
545 return llvm::to_vector<4>(getSourceVectorType().getShape());
546}
547
548LogicalResult MultiDimReductionOp::verify() {
549 SmallVector<int64_t> targetShape;
550 SmallVector<bool> scalableDims;
551 Type inferredReturnType;
552 auto sourceScalableDims = getSourceVectorType().getScalableDims();
553 for (auto [dimIdx, dimSize] :
554 llvm::enumerate(getSourceVectorType().getShape()))
555 if (!llvm::any_of(getReductionDims(),
556 [dimIdx = dimIdx](int64_t reductionDimIdx) {
557 return reductionDimIdx == static_cast<int64_t>(dimIdx);
558 })) {
559 targetShape.push_back(dimSize);
560 scalableDims.push_back(sourceScalableDims[dimIdx]);
561 }
562 // TODO: update to also allow 0-d vectors when available.
563 if (targetShape.empty())
564 inferredReturnType = getSourceVectorType().getElementType();
565 else
566 inferredReturnType = VectorType::get(
567 targetShape, getSourceVectorType().getElementType(), scalableDims);
568 if (getType() != inferredReturnType)
569 return emitOpError() << "destination type " << getType()
570 << " is incompatible with source type "
571 << getSourceVectorType();
572
573 return success();
574}
575
576/// Returns the mask type expected by this operation.
577Type MultiDimReductionOp::getExpectedMaskType() {
578 auto vecType = getSourceVectorType();
579 return VectorType::get(vecType.getShape(),
580 IntegerType::get(vecType.getContext(), /*width=*/1),
581 vecType.getScalableDims());
582}
583
584namespace {
585// Only unit dimensions that are being reduced are folded. If the dimension is
586// unit, but not reduced, it is not folded, thereby keeping the output type the
587// same. If not all dimensions which are reduced are of unit dimension, this
588// transformation does nothing. This is just a generalization of
589// ElideSingleElementReduction for ReduceOp.
590struct ElideUnitDimsInMultiDimReduction
591 : public OpRewritePattern<MultiDimReductionOp> {
592 using Base::Base;
593
594 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
595 PatternRewriter &rewriter) const override {
596 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
597 for (const auto &dim : enumerate(shape)) {
598 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
599 return failure();
600 }
601
602 // Vector mask setup.
603 OpBuilder::InsertionGuard guard(rewriter);
604 Operation *rootOp;
605 Value mask;
606 if (reductionOp.isMasked()) {
607 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
608 rootOp = reductionOp.getMaskingOp();
609 mask = reductionOp.getMaskingOp().getMask();
610 } else {
611 rootOp = reductionOp;
612 }
613
614 Location loc = reductionOp.getLoc();
615 Value acc = reductionOp.getAcc();
616 Value cast;
617 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
618 if (mask) {
619 VectorType newMaskType =
620 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
621 dstVecType.getScalableDims());
622 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
623 }
624 cast = vector::ShapeCastOp::create(
625 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
626 } else {
627 // This means we are reducing all the dimensions, and all reduction
628 // dimensions are of size 1. So a simple extraction would do.
629 if (mask)
630 mask = vector::ExtractOp::create(rewriter, loc, mask);
631 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
632 }
633
634 Value result =
635 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
636 cast, /*fastmath=*/nullptr, mask);
637 rewriter.replaceOp(rootOp, result);
638 return success();
639 }
640};
641} // namespace
642
643void MultiDimReductionOp::getCanonicalizationPatterns(
644 RewritePatternSet &results, MLIRContext *context) {
645 results.add<ElideUnitDimsInMultiDimReduction>(context);
646}
647
648//===----------------------------------------------------------------------===//
649// ReductionOp
650//===----------------------------------------------------------------------===//
651
652void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
653 CombiningKind kind, Value vector,
654 arith::FastMathFlags fastMathFlags) {
655 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
656}
657
658void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
659 CombiningKind kind, Value vector, Value acc,
660 arith::FastMathFlags fastMathFlags) {
661 build(builder, result,
662 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
663 acc, fastMathFlags);
664}
665
666LogicalResult ReductionOp::verify() {
667 // Verify for 0-D and 1-D vector.
668 int64_t rank = getSourceVectorType().getRank();
669 if (rank > 1)
670 return emitOpError("unsupported reduction rank: ") << rank;
671
672 // Verify supported reduction kind.
673 Type eltType = getDest().getType();
674 if (!isSupportedCombiningKind(getKind(), eltType))
675 return emitOpError("unsupported reduction type '")
676 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
677 << "'";
678
679 return success();
680}
681
682// MaskableOpInterface methods.
683
684/// Returns the mask type expected by this operation.
685Type ReductionOp::getExpectedMaskType() {
686 auto vecType = getSourceVectorType();
687 return VectorType::get(vecType.getShape(),
688 IntegerType::get(vecType.getContext(), /*width=*/1),
689 vecType.getScalableDims());
690}
691
// Maps an arith::AtomicRMWKind to a vector::ReductionOp with the matching
// CombiningKind, applied to `vector`. Unsupported kinds emit an optional
// error at `loc` and yield nullptr.
// NOTE(review): listing numbers 691-692 (the start of the signature,
// including the `op` parameter) are missing from this copy.
// NOTE(review): the created ops use `vector.getLoc()`, while `loc` is only
// used for the error — confirm that this is intended.
693 OpBuilder &builder, Location loc,
694 Value vector) {
695 switch (op) {
696 case arith::AtomicRMWKind::addf:
697 case arith::AtomicRMWKind::addi:
698 return vector::ReductionOp::create(builder, vector.getLoc(),
699 CombiningKind::ADD, vector);
700 case arith::AtomicRMWKind::mulf:
701 case arith::AtomicRMWKind::muli:
702 return vector::ReductionOp::create(builder, vector.getLoc(),
703 CombiningKind::MUL, vector);
704 case arith::AtomicRMWKind::minimumf:
705 return vector::ReductionOp::create(builder, vector.getLoc(),
706 CombiningKind::MINIMUMF, vector);
707 case arith::AtomicRMWKind::mins:
708 return vector::ReductionOp::create(builder, vector.getLoc(),
709 CombiningKind::MINSI, vector);
710 case arith::AtomicRMWKind::minu:
711 return vector::ReductionOp::create(builder, vector.getLoc(),
712 CombiningKind::MINUI, vector);
713 case arith::AtomicRMWKind::maximumf:
714 return vector::ReductionOp::create(builder, vector.getLoc(),
715 CombiningKind::MAXIMUMF, vector);
716 case arith::AtomicRMWKind::maxs:
717 return vector::ReductionOp::create(builder, vector.getLoc(),
718 CombiningKind::MAXSI, vector);
719 case arith::AtomicRMWKind::maxu:
720 return vector::ReductionOp::create(builder, vector.getLoc(),
721 CombiningKind::MAXUI, vector);
722 case arith::AtomicRMWKind::andi:
723 return vector::ReductionOp::create(builder, vector.getLoc(),
724 CombiningKind::AND, vector);
725 case arith::AtomicRMWKind::ori:
726 return vector::ReductionOp::create(builder, vector.getLoc(),
727 CombiningKind::OR, vector);
728 case arith::AtomicRMWKind::minnumf:
729 return vector::ReductionOp::create(builder, vector.getLoc(),
730 CombiningKind::MINNUMF, vector);
731 case arith::AtomicRMWKind::maxnumf:
732 return vector::ReductionOp::create(builder, vector.getLoc(),
733 CombiningKind::MAXNUMF, vector);
734 case arith::AtomicRMWKind::xori:
735 return vector::ReductionOp::create(builder, vector.getLoc(),
736 CombiningKind::XOR, vector);
737 default:
738 (void)emitOptionalError(loc, "Reduction operation type not supported");
739 break;
740 }
741 return nullptr;
742}
743
744std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
745 return llvm::to_vector<4>(getSourceVectorType().getShape());
746}
747
748namespace {
749struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
750 using Base::Base;
751
752 LogicalResult matchAndRewrite(ReductionOp reductionOp,
753 PatternRewriter &rewriter) const override {
754 // Vector mask setup.
755 OpBuilder::InsertionGuard guard(rewriter);
756 auto maskableOp =
757 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
758 Operation *rootOp;
759 Value mask;
760 if (maskableOp.isMasked()) {
761 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
762 rootOp = maskableOp.getMaskingOp();
763 mask = maskableOp.getMaskingOp().getMask();
764 } else {
765 rootOp = reductionOp;
766 }
767
768 auto vectorType = reductionOp.getSourceVectorType();
769 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
770 return failure();
771
772 Location loc = reductionOp.getLoc();
773 if (mask)
774 mask = ExtractOp::create(rewriter, loc, mask);
775 Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
776
777 if (Value acc = reductionOp.getAcc())
778 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
779 result, acc,
780 reductionOp.getFastmathAttr(), mask);
781
782 rewriter.replaceOp(rootOp, result);
783 return success();
784 }
785};
786} // namespace
787
788void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
789 MLIRContext *context) {
790 results.add<ElideSingleElementReduction>(context);
791}
792
793//===----------------------------------------------------------------------===//
794// ContractionOp
795//===----------------------------------------------------------------------===//
796
// Builder taking indexing expressions and iterator-type enums. Operands are
// {lhs, rhs, acc} and the result type is the accumulator type. Indexing maps
// are inferred from `indexingExprs`; each IteratorType is wrapped in an
// IteratorTypeAttr.
// NOTE(review): listing number 798 (the parameter line that must declare
// `lhs`, `rhs` and `acc`) is missing from this copy.
797void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
799 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
800 ArrayRef<IteratorType> iteratorTypes) {
801 result.addOperands({lhs, rhs, acc});
802 result.addTypes(acc.getType());
803 result.addAttribute(
804 getIndexingMapsAttrName(result.name),
805 builder.getAffineMapArrayAttr(
806 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
807 result.addAttribute(
808 getIteratorTypesAttrName(result.name),
809 builder.getArrayAttr(llvm::map_to_vector(
810 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
811 return IteratorTypeAttr::get(builder.getContext(), t);
812 })));
813}
814
// Builder taking ready-made attribute arrays; forwards to the overload with
// an explicit kind, passing ContractionOp::getDefaultKind().
// NOTE(review): listing number 816 (the parameter line that must declare
// `lhs`, `rhs` and `acc`) is missing from this copy.
815void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
817 ArrayAttr indexingMaps,
818 ArrayAttr iteratorTypes) {
819 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
820 ContractionOp::getDefaultKind());
821}
822
// Builder taking attribute arrays, a combining kind and fastmath flags. The
// fastmath attribute is only attached when the flags differ from `none`.
// NOTE(review): listing number 824 (the parameter line that must declare
// `lhs`, `rhs` and `acc`) is missing from this copy.
823void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
825 ArrayAttr indexingMaps,
826 ArrayAttr iteratorTypes, CombiningKind kind,
827 arith::FastMathFlags fastMathFlags) {
828 result.addOperands({lhs, rhs, acc});
829 result.addTypes(acc.getType());
830 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
831 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
832 result.addAttribute(getKindAttrName(result.name),
833 CombiningKindAttr::get(builder.getContext(), kind));
834 if (fastMathFlags != arith::FastMathFlags::none)
835 result.addAttribute(
836 getFastmathAttrName(result.name),
837 arith::FastMathFlagsAttr::get(builder.getContext(), fastMathFlags));
838}
839
// Custom parser. The visible sequence is: a dictionary attribute, the lhs,
// rhs and acc operands, an optional trailing mask operand list, an optional
// attribute dictionary, a colon type list and `into <result type>`. The
// iterator types given as strings are converted to IteratorTypeAttr, a
// default kind is added when none is present, and either zero or exactly two
// masks are accepted.
// NOTE(review): listing numbers 841-845 are missing from this copy; they
// must declare `lhsInfo`, `rhsInfo`, `accInfo`, `masksInfo` and `types`.
// NOTE(review): `types[0]` and `types[1]` are indexed without a visible size
// check on the parsed type list — confirm that a short list cannot reach them.
840ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
846 Type resultType;
847 auto loc = parser.getCurrentLocation();
848 DictionaryAttr dictAttr;
849 // TODO: Unify linalg op attribute parsing.
850 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
851 parser.parseComma() || parser.parseOperand(rhsInfo) ||
852 parser.parseComma() || parser.parseOperand(accInfo) ||
853 parser.parseTrailingOperandList(masksInfo) ||
854 parser.parseOptionalAttrDict(result.attributes) ||
855 parser.parseColonTypeList(types) ||
856 parser.parseKeywordType("into", resultType) ||
857 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
858 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
859 parser.resolveOperand(accInfo, resultType, result.operands) ||
860 parser.addTypeToList(resultType, result.types))
861 return failure();
862 result.attributes.append(dictAttr.getValue().begin(),
863 dictAttr.getValue().end());
864
865 // Convert array of strings into an array of IteratorType enums. This is
866 // needed because tests still use the old format where the 'iterator_types'
867 // attribute is represented as an array of strings.
868 // TODO: Remove this conversion once tests are fixed.
869 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
870 result.attributes.get(getIteratorTypesAttrName(result.name)));
871 if (!iteratorTypes) {
872 return parser.emitError(loc)
873 << "expected " << getIteratorTypesAttrName(result.name)
874 << " array attribute";
875 }
876
877 SmallVector<Attribute> iteratorTypeAttrs;
878
879 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
880 auto maybeIteratorType = symbolizeIteratorType(s);
881 if (!maybeIteratorType.has_value())
882 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
883
884 iteratorTypeAttrs.push_back(
885 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
886 }
887 result.attributes.set(getIteratorTypesAttrName(result.name),
888 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
889
890 if (!result.attributes.get(getKindAttrName(result.name))) {
891 result.addAttribute(
892 getKindAttrName(result.name),
893 CombiningKindAttr::get(result.getContext(),
894 ContractionOp::getDefaultKind()));
895 }
896 if (masksInfo.empty())
897 return success();
898 if (masksInfo.size() != 2)
899 return parser.emitError(parser.getNameLoc(),
900 "expected zero or exactly 2 vector mask operands");
901 auto lhsType = llvm::cast<VectorType>(types[0]);
902 auto rhsType = llvm::cast<VectorType>(types[1]);
903 auto maskElementType = parser.getBuilder().getI1Type();
904 std::array<VectorType, 2> maskTypes = {
905 VectorType::Builder(lhsType).setElementType(maskElementType),
906 VectorType::Builder(rhsType).setElementType(maskElementType)};
907 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
908 return failure();
909 return success();
910}
911
// Custom printer. Trait attributes are gathered into a leading dictionary
// (iterator types as strings, fastmath omitted when `none`), followed by the
// lhs, rhs and acc operands, the remaining attributes and the types.
// NOTE(review): listing number 917 (the declaration of `attrs`) is missing
// from this copy.
// NOTE(review): parse() accepts trailing mask operands, but no mask operands
// are printed here — confirm whether that asymmetry is intended.
912void ContractionOp::print(OpAsmPrinter &p) {
913 // TODO: Unify printing code with linalg ops.
914 auto attrNames = getTraitAttrNames();
915 llvm::StringSet<> traitAttrsSet;
916 traitAttrsSet.insert_range(attrNames);
918 for (auto attr : (*this)->getAttrs()) {
919 if (attr.getName() == getIteratorTypesAttrName()) {
920 auto iteratorTypes =
921 llvm::cast<ArrayAttr>(attr.getValue())
922 .getAsValueRange<IteratorTypeAttr, IteratorType>();
923 // Convert IteratorType enums into the string representation. This is
924 // needed, because tests still use the old format when 'iterator_types'
925 // attribute is represented as an array of strings.
926 // TODO: Remove this conversion once tests are fixed.
927 SmallVector<Attribute> iteratorTypeNames =
928 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
929 return StringAttr::get(getContext(), stringifyIteratorType(t));
930 });
931
932 attrs.emplace_back(getIteratorTypesAttrName(),
933 ArrayAttr::get(getContext(), iteratorTypeNames));
934 } else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
935 // Omit fastmath when it equals the default (none) to keep output clean.
936 if (attr.getName() == getFastmathAttrName() &&
937 llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
938 arith::FastMathFlags::none)
939 continue;
940 attrs.push_back(attr);
941 }
942 }
943
944 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
945 p << " " << dictAttr << " " << getLhs() << ", ";
946 p << getRhs() << ", " << getAcc();
947
948 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
949 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
950 << getResultType();
951}
952
953static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
954 const std::vector<std::pair<int64_t, int64_t>> &map) {
955 for (auto &dimPair : map) {
956 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
957 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
958 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
959 return false;
960 }
961 return true;
962}
963
964static LogicalResult verifyOutputShape(
965 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
966 Type resType,
967 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
968 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
969 DenseSet<int64_t> lhsContractingDimSet;
970 DenseSet<int64_t> rhsContractingDimSet;
971 for (auto &dimPair : contractingDimMap) {
972 lhsContractingDimSet.insert(dimPair.first);
973 rhsContractingDimSet.insert(dimPair.second);
974 }
975 DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
976 llvm::make_second_range(batchDimMap));
977
978 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
979 SmallVector<int64_t, 4> expectedResultDims;
980 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
981 if (lhsContractingDimSet.count(i) > 0)
982 continue;
983 expectedResultDims.push_back(lhsType.getDimSize(i));
984 }
985
986 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
987 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
988 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
989 continue;
990 expectedResultDims.push_back(rhsType.getDimSize(i));
991 }
992
993 // Verify 'expectedResultDims'.
994 if (expectedResultDims.empty()) {
995 // No batch or free dimension implies a scalar result.
996 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
997 return op.emitOpError("invalid accumulator/result vector shape");
998 } else {
999 // At least one batch or free dimension implies a vector result.
1000 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
1001 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
1002 if (!resVectorType || !accVectorType)
1003 return op.emitOpError("invalid accumulator/result vector shape");
1004
1005 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
1006 // types fully define the result vector type. This assumes the affine maps
1007 // are well-formed, which must have been verified already.
1008 MLIRContext *ctx = op.getContext();
1009 AffineMap lhsMap = op.getIndexingMapsArray()[0];
1010 AffineMap rhsMap = op.getIndexingMapsArray()[1];
1011 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
1012 return op.emitOpError(
1013 "expected all dimensions to be either a LHS or a RHS dimension");
1014 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
1015 for (auto pair :
1016 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1017 VectorType v = pair.first;
1018 auto map = pair.second;
1019 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1020 unsigned pos = map.getDimPosition(idx);
1021 if (!extents[pos])
1022 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
1023 }
1024 }
1025 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
1026 return op.emitOpError("expected all dimensions to get an extent as "
1027 "either a LHS or a RHS dimension");
1028
1029 AffineMap resMap = op.getIndexingMapsArray()[2];
1030 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
1031 /*symbolCount=*/0, extents, ctx);
1032 // Compose the resMap with the extentsMap, which is a constant map.
1033 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
1034 assert(llvm::all_of(expectedMap.getResults(),
1035 llvm::IsaPred<AffineConstantExpr>) &&
1036 "expected constant extent along all dimensions.");
1037 // Extract the expected shape and build the type.
1038 auto expectedShape =
1039 llvm::map_to_vector<4>(expectedMap.getResults(), [](AffineExpr e) {
1040 return cast<AffineConstantExpr>(e).getValue();
1041 });
1042 auto expected =
1043 VectorType::get(expectedShape, resVectorType.getElementType(),
1044 resVectorType.getScalableDims());
1045 if (resVectorType != expected || accVectorType != expected)
1046 return op.emitOpError(
1047 "invalid accumulator/result vector shape, expected: ")
1048 << expected;
1049 }
1050 return success();
1051}
1052
1053LogicalResult ContractionOp::verify() {
1054 VectorType lhsType = getLhsType();
1055 VectorType rhsType = getRhsType();
1056 Type accType = getAccType();
1057 Type resType = getResultType();
1058
1059 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1060 if (!lhsType.getElementType().isSignlessInteger())
1061 return emitOpError("only supports signless integer types");
1062 }
1063
1064 // Verify that an indexing map was specified for each vector operand.
1065 if (getIndexingMapsArray().size() != 3)
1066 return emitOpError("expected an indexing map for each vector operand");
1067
1068 // Verify that each index map has 'numIterators' inputs, no symbols, and
1069 // that the number of map outputs equals the rank of its associated
1070 // vector operand.
1071 unsigned numIterators = getIteratorTypes().getValue().size();
1072 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1073 auto index = it.index();
1074 auto map = it.value();
1075 if (map.getNumSymbols() != 0)
1076 return emitOpError("expected indexing map ")
1077 << index << " to have no symbols";
1078 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1079 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1080 // Verify that the map has the right number of inputs, outputs, and indices.
1081 // This also correctly accounts for (..) -> () for rank-0 results.
1082 if (map.getNumDims() != numIterators)
1083 return emitOpError("expected indexing map ")
1084 << index << " to have " << numIterators << " number of inputs";
1085 if (map.getNumResults() != rank)
1086 return emitOpError("expected indexing map ")
1087 << index << " to have " << rank << " number of outputs";
1088 if (!map.isProjectedPermutation())
1089 return emitOpError("expected indexing map ")
1090 << index << " to be a projected permutation of its inputs";
1091 }
1092
1093 auto contractingDimMap = getContractingDimMap();
1094 auto batchDimMap = getBatchDimMap();
1095
1096 // Verify at least one contracting dimension pair was specified.
1097 if (contractingDimMap.empty())
1098 return emitOpError("expected at least one contracting dimension pair");
1099
1100 // Verify contracting dimension map was properly constructed.
1101 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1102 return emitOpError("invalid contracting dimension map");
1103
1104 // Verify batch dimension map was properly constructed.
1105 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1106 return emitOpError("invalid batch dimension map");
1107
1108 // Verify 'accType' and 'resType' shape.
1109 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1110 contractingDimMap, batchDimMap)))
1111 return failure();
1112
1113 if (!getKindAttr()) {
1114 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
1115 "'vector.kind<add>')");
1116 }
1117
1118 // Verify supported combining kind.
1119 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1120 auto elementType = vectorType ? vectorType.getElementType() : resType;
1121 if (!isSupportedCombiningKind(getKind(), elementType))
1122 return emitOpError("unsupported contraction type");
1123
1124 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1125 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1126}
1127
1128// MaskableOpInterface methods.
1129
1130/// Returns the mask type expected by this operation. Mostly used for
1131/// verification purposes. It requires the operation to be vectorized."
1132Type ContractionOp::getExpectedMaskType() {
1133 auto indexingMaps = this->getIndexingMapsArray();
1134 AffineMap lhsIdxMap = indexingMaps[0];
1135 AffineMap rhsIdxMap = indexingMaps[1];
1136 VectorType lhsType = this->getLhsType();
1137 VectorType rhsType = this->getRhsType();
1138
1139 unsigned numVecDims = lhsIdxMap.getNumDims();
1140 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1141 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1142
1143 // Using the information in the indexing maps, extract the size of each
1144 // dimension in the vector.contract operation from the two input operands.
1145 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1146 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1147 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1148 lhsType.getScalableDims()[dimIdx];
1149 }
1150 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1151 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1152 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1153 rhsType.getScalableDims()[dimIdx];
1154 }
1155
1156 assert(ShapedType::isStaticShape(maskShape) &&
1157 "Mask shape couldn't be computed");
1158
1159 return VectorType::get(maskShape,
1160 IntegerType::get(lhsType.getContext(), /*width=*/1),
1161 maskShapeScalableDims);
1162}
1163
1164SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1165 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1166 getIteratorTypesAttrName(), getKindAttrName(),
1167 getFastmathAttrName()};
1168}
1169
1171 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1172 if (targetExpr == map.getResult(i))
1173 return i;
1174 return -1;
1175}
1176
1177static std::vector<std::pair<int64_t, int64_t>>
1178getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1179 IteratorType targetIteratorType, MLIRContext *context) {
1180 std::vector<std::pair<int64_t, int64_t>> dimMap;
1181 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1182 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1183 if (iteratorType != targetIteratorType)
1184 continue;
1185 // Search lhs/rhs map results for 'targetExpr'.
1186 auto targetExpr = getAffineDimExpr(it.index(), context);
1187 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1188 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1189 if (lhsDim >= 0 && rhsDim >= 0)
1190 dimMap.emplace_back(lhsDim, rhsDim);
1191 }
1192 return dimMap;
1193}
1194
1195void ContractionOp::getIterationBounds(
1196 SmallVectorImpl<int64_t> &iterationBounds) {
1197 auto lhsShape = getLhsType().getShape();
1198 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1199 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1200 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1201 // Search lhs/rhs map results for 'targetExpr'.
1202 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1203 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1204 if (iteratorType == IteratorType::reduction) {
1205 // Get reduction dim size from lhs shape (same size in rhsShape).
1206 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1207 assert(lhsDimIndex >= 0);
1208 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1209 continue;
1210 }
1211 // Get parallel dimension size from result shape.
1212 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1213 assert(resDimIndex >= 0);
1214 assert(resVectorType != nullptr);
1215 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1216 }
1217}
1218
1219void ContractionOp::getIterationIndexMap(
1220 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1221 unsigned numMaps = getIndexingMapsArray().size();
1222 iterationIndexMap.resize(numMaps);
1223 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1224 auto index = it.index();
1225 auto map = it.value();
1226 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1227 auto dim = cast<AffineDimExpr>(map.getResult(i));
1228 iterationIndexMap[index][dim.getPosition()] = i;
1229 }
1230 }
1231}
1232
1233std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1234 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1235 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1236 getContext());
1237}
1238
1239std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1240 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1241 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1242 getContext());
1243}
1244
1245std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1247 getIterationBounds(shape);
1248 return shape;
1249}
1250
1251/// Return a fused vector::ContractionOp which represents a patterns such as:
1252///
1253/// ```mlir
1254/// %c0 = vector.constant 0: ...
1255/// %c = vector.contract %a, %b, %c0: ...
1256/// %e = add %c, %d: ...
1257/// ```
1258///
1259/// by:
1260///
1261/// ```mlir
1262/// %e = vector.contract %a, %b, %d: ...
1263/// ```
1264///
1265/// Return null if the canonicalization does not apply.
1266// TODO: This should be a folding of Add into Contract in core but while they
1267// live in different dialects, it is not possible without unnatural
1268// dependencies.
1269template <typename AddOpType>
1270struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1271 using OpRewritePattern<AddOpType>::OpRewritePattern;
1272
1273 LogicalResult matchAndRewrite(AddOpType addOp,
1274 PatternRewriter &rewriter) const override {
1275 auto canonicalize = [&](Value maybeContraction,
1276 Value otherOperand) -> vector::ContractionOp {
1277 vector::ContractionOp contractionOp =
1278 dyn_cast_or_null<vector::ContractionOp>(
1279 maybeContraction.getDefiningOp());
1280 if (!contractionOp)
1281 return vector::ContractionOp();
1282 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1283 contractionOp.getAcc().getDefiningOp())) {
1284 if (maybeZero.getValue() ==
1285 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1286 IRMapping bvm;
1287 bvm.map(contractionOp.getAcc(), otherOperand);
1288 auto newContraction =
1289 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1290 rewriter.replaceOp(addOp, newContraction.getResult());
1291 return newContraction;
1292 }
1293 }
1294 return vector::ContractionOp();
1295 };
1296
1297 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1298 vector::ContractionOp contract = canonicalize(a, b);
1299 contract = contract ? contract : canonicalize(b, a);
1300 return contract ? success() : failure();
1301 }
1302};
1303
1304void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1305 MLIRContext *context) {
1308}
1309
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}
1316
1317//===----------------------------------------------------------------------===//
1318// ExtractOp
1319//===----------------------------------------------------------------------===//
1320
1321void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1322 SetIntRangeFn setResultRanges) {
1323 setResultRanges(getResult(), argRanges.front());
1324}
1325
1326void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1327 Value source) {
1328 auto vectorTy = cast<VectorType>(source.getType());
1329 build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1330}
1331
1332void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1333 Value source, int64_t position) {
1334 build(builder, result, source, ArrayRef<int64_t>{position});
1335}
1336
1337void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1338 Value source, OpFoldResult position) {
1339 build(builder, result, source, ArrayRef<OpFoldResult>{position});
1340}
1341
1342void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1343 Value source, ArrayRef<int64_t> position) {
1344 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1345 builder.getDenseI64ArrayAttr(position));
1346}
1347
1348void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1349 Value source, ArrayRef<OpFoldResult> position) {
1350 SmallVector<int64_t> staticPos;
1351 SmallVector<Value> dynamicPos;
1352 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1353 build(builder, result, source, dynamicPos,
1354 builder.getDenseI64ArrayAttr(staticPos));
1355}
1356
1357LogicalResult
1358ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1359 ExtractOp::Adaptor adaptor,
1360 SmallVectorImpl<Type> &inferredReturnTypes) {
1361 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1362 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1363 vectorType.getRank()) {
1364 inferredReturnTypes.push_back(vectorType.getElementType());
1365 } else {
1366 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1367 vectorType.getRank());
1368 inferredReturnTypes.push_back(VectorType::get(
1369 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1370 vectorType.getScalableDims().drop_front(n)));
1371 }
1372 return success();
1373}
1374
1375LogicalResult vector::ExtractOp::verify() {
1376 if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1377 if (resTy.getRank() == 0)
1378 return emitError(
1379 "expected a scalar instead of a 0-d vector as the result type");
1380
1381 // Note: This check must come before getMixedPosition() to prevent a crash.
1382 auto dynamicMarkersCount =
1383 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1384 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1385 return emitOpError(
1386 "mismatch between dynamic and static positions (kDynamic marker but no "
1387 "corresponding dynamic position) -- this can only happen due to an "
1388 "incorrect fold/rewrite");
1389 auto position = getMixedPosition();
1390 if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1391 return emitOpError(
1392 "expected position attribute of rank no greater than vector rank");
1393 for (auto [idx, pos] : llvm::enumerate(position)) {
1394 if (auto attr = dyn_cast<Attribute>(pos)) {
1395 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1397 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1398 return emitOpError("expected position attribute #")
1399 << (idx + 1)
1400 << " to be a non-negative integer smaller than the "
1401 "corresponding vector dimension or poison (-1)";
1402 }
1403 }
1404 }
1405 return success();
1406}
1407
1408template <typename IntType>
1410 return llvm::map_to_vector<4>(
1411 arrayAttr.getAsRange<IntegerAttr>(),
1412 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1413}
1414
1415/// Fold the result of chains of ExtractOp in place by simply concatenating the
1416/// positions.
1417static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1418 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1419 return failure();
1420
1421 // TODO: Canonicalization for dynamic position not implemented yet.
1422 if (extractOp.hasDynamicPosition())
1423 return failure();
1424
1425 SmallVector<int64_t> globalPosition;
1426 ExtractOp currentOp = extractOp;
1427 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1428 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1429 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1430 currentOp = nextOp;
1431 // TODO: Canonicalization for dynamic position not implemented yet.
1432 if (currentOp.hasDynamicPosition())
1433 return failure();
1434 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1435 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1436 }
1437 extractOp.setOperand(0, currentOp.getSource());
1438 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1439 OpBuilder b(extractOp.getContext());
1440 std::reverse(globalPosition.begin(), globalPosition.end());
1441 extractOp.setStaticPosition(globalPosition);
1442 return success();
1443}
1444
1445namespace {
1446/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1447/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1448/// Compose TransposeOp permutations as we walk back.
1449/// This helper class keeps an updated extraction position `extractPosition`
1450/// with extra trailing sentinels.
1451/// The sentinels encode the internal transposition status of the result vector.
1452/// As we iterate, extractPosition is permuted and updated.
1453class ExtractFromInsertTransposeChainState {
1454public:
1455 ExtractFromInsertTransposeChainState(ExtractOp e);
1456
1457 /// Iterate over producing insert and transpose ops until we find a fold.
1458 Value fold();
1459
1460private:
1461 /// Return true if the vector at position `a` is contained within the vector
1462 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1463 /// is a prefix of `b`.
1464 template <typename ContainerA, typename ContainerB>
1465 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1466 return a.size() <= b.size() &&
1467 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1468 }
1469
1470 /// Return true if the vector at position `a` intersects the vector at
1471 /// position `b`. Under insert/extract semantics, this is the same as equality
1472 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1473 /// Comparison is on the common prefix (i.e. zip).
1474 template <typename ContainerA, typename ContainerB>
1475 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1476 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1477 if (elemA < 0 || elemB < 0)
1478 continue;
1479 if (elemA != elemB)
1480 return false;
1481 }
1482 return true;
1483 }
1484
1485 /// Folding is only possible in the absence of an internal permutation in the
1486 /// result vector.
1487 bool canFold() {
1488 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1489 }
1490
1491 // Helper to get the next defining op of interest.
1492 void updateStateForNextIteration(Value v) {
1493 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1494 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1495 };
1496
1497 // Case 1. If we hit a transpose, just compose the map and iterate.
1498 // Invariant: insert + transpose do not change rank, we can always compose.
1499 LogicalResult handleTransposeOp();
1500
1501 // Case 2: the insert position matches extractPosition exactly, early return.
1502 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1503
1504 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1505 /// portion of the source of the insert.
1506 /// Example:
1507 /// ```
1508 /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
1509 /// // extractPosition == [1, 2, 3]
1510 /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1511 /// // can fold to vector.extract %source[0, 3]
1512 /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1513 /// ```
1514 /// To traverse through %source, we need to set the leading dims to 0 and
1515 /// drop the extra leading dims.
1516 /// This method updates the internal state.
1517 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1518
1519 /// Try to fold in place to extract(source, extractPosition) and return the
1520 /// folded result. Return null if folding is not possible (e.g. due to an
1521 /// internal transposition in the result).
1522 Value tryToFoldExtractOpInPlace(Value source);
1523
1524 ExtractOp extractOp;
1525 int64_t vectorRank;
1526 int64_t extractedRank;
1527
1528 InsertOp nextInsertOp;
1529 TransposeOp nextTransposeOp;
1530
1531 /// Sentinel values that encode the internal permutation status of the result.
1532 /// They are set to (-1, ... , -k) at the beginning and appended to
1533 /// `extractPosition`.
1534 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1535 /// ensure that there is no internal transposition.
1536 /// Internal transposition cannot be accounted for with a folding pattern.
1537 // TODO: We could relax the internal transposition with an extra transposition
1538 // operation in a future canonicalizer.
1539 SmallVector<int64_t> sentinels;
1540 SmallVector<int64_t> extractPosition;
1541};
1542} // namespace
1543
1544ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1545 ExtractOp e)
1546 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1547 extractedRank(extractOp.getNumIndices()) {
1548 assert(vectorRank >= extractedRank && "Extracted position overflow");
1549 sentinels.reserve(vectorRank - extractedRank);
1550 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1551 sentinels.push_back(-(i + 1));
1552 extractPosition.assign(extractOp.getStaticPosition().begin(),
1553 extractOp.getStaticPosition().end());
1554 llvm::append_range(extractPosition, sentinels);
1555}
1556
1557// Case 1. If we hit a transpose, just compose the map and iterate.
1558// Invariant: insert + transpose do not change rank, we can always compose.
1559LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1560 // TODO: Canonicalization for dynamic position not implemented yet.
1561 if (extractOp.hasDynamicPosition())
1562 return failure();
1563
1564 if (!nextTransposeOp)
1565 return failure();
1567 nextTransposeOp.getPermutation(), extractOp.getContext()));
1569 return success();
1570}
1571
1572// Case 2: the insert position matches extractPosition exactly, early return.
1573LogicalResult
1574ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1575 Value &res) {
1576 // TODO: Canonicalization for dynamic position not implemented yet.
1577 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1578 return failure();
1579
1580 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1581 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1582 return failure();
1583 // Case 2.a. early-exit fold.
1584 res = nextInsertOp.getValueToStore();
1585 // Case 2.b. if internal transposition is present, canFold will be false.
1586 return success(canFold());
1587}
1588
1589/// Case 3: if inserted position is a prefix of extractPosition,
1590/// extract a portion of the source of the insertion.
1591/// This method updates the internal state.
1592LogicalResult
1593ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1594 // TODO: Canonicalization for dynamic position not implemented yet.
1595 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1596 return failure();
1597
1598 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1599 if (!isContainedWithin(insertedPos, extractPosition))
1600 return failure();
1601 // Set leading dims to zero.
1602 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1603 // Drop extra leading dims.
1604 extractPosition.erase(extractPosition.begin(),
1605 extractPosition.begin() + insertedPos.size());
1606 extractedRank = extractPosition.size() - sentinels.size();
1607 // Case 3.a. early-exit fold (break and delegate to post-while path).
1608 res = nextInsertOp.getValueToStore();
1609 // Case 3.b. if internal transposition is present, canFold will be false.
1610 return success();
1611}
1612
1613/// Try to fold in place to extract(source, extractPosition) and return the
1614/// folded result. Return null if folding is not possible (e.g. due to an
1615/// internal transposition in the result).
1616Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1617 Value source) {
1618 // TODO: Canonicalization for dynamic position not implemented yet.
1619 if (extractOp.hasDynamicPosition())
1620 return Value();
1621
1622 // If we can't fold (either internal transposition, or nothing to fold), bail.
1623 bool nothingToFold = (source == extractOp.getSource());
1624 if (nothingToFold || !canFold())
1625 return Value();
1626
1627 // Otherwise, fold by updating the op inplace and return its result.
1628 OpBuilder b(extractOp.getContext());
1629 extractOp.setStaticPosition(
1630 ArrayRef(extractPosition).take_front(extractedRank));
1631 extractOp.getSourceMutable().assign(source);
1632 return extractOp.getResult();
1633}
1634
1635/// Iterate over producing insert and transpose ops until we find a fold.
1636Value ExtractFromInsertTransposeChainState::fold() {
1637 // TODO: Canonicalization for dynamic position not implemented yet.
1638 if (extractOp.hasDynamicPosition())
1639 return Value();
1640
1641 Value valueToExtractFrom = extractOp.getSource();
1642 updateStateForNextIteration(valueToExtractFrom);
1643 while (nextInsertOp || nextTransposeOp) {
1644 // Case 1. If we hit a transpose, just compose the map and iterate.
1645 // Invariant: insert + transpose do not change rank, we can always compose.
1646 if (succeeded(handleTransposeOp())) {
1647 valueToExtractFrom = nextTransposeOp.getVector();
1648 updateStateForNextIteration(valueToExtractFrom);
1649 continue;
1650 }
1651
1652 Value result;
1653 // Case 2: the position match exactly.
1654 if (succeeded(handleInsertOpWithMatchingPos(result)))
1655 return result;
1656
1657 // Case 3: if the inserted position is a prefix of extractPosition, we can
1658 // just extract a portion of the source of the insert.
1659 if (succeeded(handleInsertOpWithPrefixPos(result)))
1660 return tryToFoldExtractOpInPlace(result);
1661
1662 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1663 // values. This is a more difficult case and we bail.
1664 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1665 if (isContainedWithin(extractPosition, insertedPos) ||
1666 intersectsWhereNonNegative(extractPosition, insertedPos))
1667 return Value();
1668
1669 // Case 5: No intersection, we forward the extract to insertOp.dest().
1670 valueToExtractFrom = nextInsertOp.getDest();
1671 updateStateForNextIteration(valueToExtractFrom);
1672 }
1673 // If after all this we can fold, go for it.
1674 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1675}
1676
1677/// Returns true if the operation has a 0-D vector type operand or result.
1679 auto hasZeroDimVectorType = [](Type type) -> bool {
1680 auto vecType = dyn_cast<VectorType>(type);
1681 return vecType && vecType.getRank() == 0;
1682 };
1683
1684 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1685 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1686}
1687
1688/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1689/// considered to be 'broadcastlike'.
1690static bool isBroadcastLike(Operation *op) {
1691 if (isa<BroadcastOp>(op))
1692 return true;
1693
1694 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1695 if (!shapeCast)
1696 return false;
1697
1698 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1699 // Checking that the destination shape has a prefix of 1s is not sufficient,
1700 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1701 // is that the source shape is a suffix of the destination shape.
1702 VectorType srcType = shapeCast.getSourceVectorType();
1703 ArrayRef<int64_t> srcShape = srcType.getShape();
1704 uint64_t srcRank = srcType.getRank();
1705 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1706 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1707}
1708
1709/// Fold extract(broadcast(X)) to either extract(X) or just X.
1710///
1711/// Example:
1712///
1713/// broadcast extract [1][2]
1714/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1715///
1716/// becomes
1717/// extract [1]
1718/// (3,4) -------------------------------------> (4)
1719///
1720///
1721/// The variable names used in this implementation correspond to the above
1722/// shapes as,
1723///
1724/// - (3, 4) is `input` shape.
1725/// - (2, 3, 4) is `broadcast` shape.
1726/// - (4) is `extract` shape.
1727///
1728/// This folding is possible when the suffix of `input` shape is the same as
1729/// `extract` shape.
1730static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1731
1732 Operation *defOp = extractOp.getSource().getDefiningOp();
1733 if (!defOp || !isBroadcastLike(defOp))
1734 return Value();
1735
1736 Value input = defOp->getOperand(0);
1737
1738 // Replace extract(broadcast(X)) with X
1739 if (extractOp.getType() == input.getType())
1740 return input;
1741
1742 // Get required types and ranks in the chain
1743 // input -> broadcast -> extract
1744 // (scalars are treated as rank-0).
1745 auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1746 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1747 unsigned inputRank = inputType ? inputType.getRank() : 0;
1748 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1749 unsigned extractRank = extractType ? extractType.getRank() : 0;
1750
1751 // Cannot do without the broadcast if overall the rank increases.
1752 if (extractRank > inputRank)
1753 return Value();
1754
1755 // The above condition guarantees that input is a vector.
1756 assert(inputType && "input must be a vector type because of previous checks");
1757 ArrayRef<int64_t> inputShape = inputType.getShape();
1758
1759 // In the case where there is a broadcast dimension in the suffix, it is not
1760 // possible to replace extract(broadcast(X)) with extract(X). Example:
1761 //
1762 // broadcast extract
1763 // (1) --------> (3,4) ------> (4)
1764 if (extractType &&
1765 extractType.getShape() != inputShape.take_back(extractRank))
1766 return Value();
1767
1768 // Replace extract(broadcast(X)) with extract(X).
1769 // First, determine the new extraction position.
1770 unsigned deltaOverall = inputRank - extractRank;
1771 unsigned deltaBroadcast = broadcastRank - inputRank;
1772 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1773 SmallVector<OpFoldResult> newPositions(deltaOverall);
1774 IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1775 for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1776 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1777 }
1778 auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1779 extractOp->setOperands(
1780 llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1781 extractOp.setStaticPosition(staticPos);
1782 return extractOp.getResult();
1783}
1784
1785/// Fold extractOp coming from ShuffleOp.
1786///
1787/// Example:
1788///
1789/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1790/// : vector<8xf32>, vector<8xf32>
1791/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1792/// ->
1793/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1794///
1795static Value foldExtractFromShuffle(ExtractOp extractOp) {
1796 // Dynamic positions are not folded as the resulting code would be more
1797 // complex than the input code.
1798 if (extractOp.hasDynamicPosition())
1799 return Value();
1800
1801 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1802 if (!shuffleOp)
1803 return Value();
1804
1805 // TODO: 0-D or multi-dimensional vectors not supported yet.
1806 if (shuffleOp.getResultVectorType().getRank() != 1)
1807 return Value();
1808
1809 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1810 auto shuffleMask = shuffleOp.getMask();
1811 int64_t extractIdx = extractOp.getStaticPosition()[0];
1812 int64_t shuffleIdx = shuffleMask[extractIdx];
1813
1814 // Find the shuffled vector to extract from based on the shuffle index.
1815 if (shuffleIdx < inputVecSize) {
1816 extractOp.setOperand(0, shuffleOp.getV1());
1817 extractOp.setStaticPosition({shuffleIdx});
1818 } else {
1819 extractOp.setOperand(0, shuffleOp.getV2());
1820 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1821 }
1822
1823 return extractOp.getResult();
1824}
1825
1826// Fold extractOp with source coming from ShapeCast op.
1827static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1828 // TODO: Canonicalization for dynamic position not implemented yet.
1829 if (extractOp.hasDynamicPosition())
1830 return Value();
1831
1832 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1833 if (!shapeCastOp)
1834 return Value();
1835
1836 // Get the nth dimension size starting from lowest dimension.
1837 auto getDimReverse = [](VectorType type, int64_t n) {
1838 return type.getShape().take_back(n + 1).front();
1839 };
1840 int64_t destinationRank =
1841 llvm::isa<VectorType>(extractOp.getType())
1842 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1843 : 0;
1844 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1845 return Value();
1846 if (destinationRank > 0) {
1847 auto destinationType =
1848 llvm::cast<VectorType>(extractOp.getResult().getType());
1849 for (int64_t i = 0; i < destinationRank; i++) {
1850 // The lowest dimension of the destination must match the lowest
1851 // dimension of the shapecast op source.
1852 // TODO: This case could be support in a canonicalization pattern.
1853 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1854 getDimReverse(destinationType, i))
1855 return Value();
1856 }
1857 }
1858 // Extract the strides associated with the extract op vector source. Then use
1859 // this to calculate a linearized position for the extract.
1860 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1861 std::reverse(extractedPos.begin(), extractedPos.end());
1863 int64_t stride = 1;
1864 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1865 strides.push_back(stride);
1866 stride *=
1867 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1868 }
1869
1870 int64_t position = linearize(extractedPos, strides);
1871 // Then extract the strides associated to the shapeCast op vector source and
1872 // delinearize the position using those strides.
1873 SmallVector<int64_t, 4> newStrides;
1874 int64_t numDimension =
1875 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1876 stride = 1;
1877 for (int64_t i = 0; i < numDimension; i++) {
1878 newStrides.push_back(stride);
1879 stride *=
1880 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1881 }
1882 std::reverse(newStrides.begin(), newStrides.end());
1883 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1884 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1885 OpBuilder b(extractOp.getContext());
1886 extractOp.setStaticPosition(newPosition);
1887 extractOp.setOperand(0, shapeCastOp.getSource());
1888 return extractOp.getResult();
1889}
1890
1891/// Fold an ExtractOp from ExtractStridedSliceOp.
1892static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1893 // TODO: Canonicalization for dynamic position not implemented yet.
1894 if (extractOp.hasDynamicPosition())
1895 return Value();
1896
1897 auto extractStridedSliceOp =
1898 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1899 if (!extractStridedSliceOp)
1900 return Value();
1901
1902 // 0-D vectors not supported.
1903 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1904 if (hasZeroDimVectors(extractStridedSliceOp))
1905 return Value();
1906
1907 // Return if 'extractStridedSliceOp' has non-unit strides.
1908 if (extractStridedSliceOp.hasNonUnitStrides())
1909 return Value();
1910
1911 // Trim offsets for dimensions fully extracted.
1912 auto sliceOffsets =
1913 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1914 while (!sliceOffsets.empty()) {
1915 size_t lastOffset = sliceOffsets.size() - 1;
1916 if (sliceOffsets.back() != 0 ||
1917 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1918 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1919 break;
1920 sliceOffsets.pop_back();
1921 }
1922 unsigned destinationRank = 0;
1923 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1924 destinationRank = vecType.getRank();
1925 // The dimensions of the result need to be untouched by the
1926 // extractStridedSlice op.
1927 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1928 sliceOffsets.size())
1929 return Value();
1930
1931 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1932 assert(extractedPos.size() >= sliceOffsets.size());
1933 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1934 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1935 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1936
1937 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1938 OpBuilder b(extractOp.getContext());
1939 extractOp.setStaticPosition(extractedPos);
1940 return extractOp.getResult();
1941}
1942
1943/// Fold extract_op fed from a chain of insertStridedSlice ops.
1944static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1945 // TODO: Canonicalization for dynamic position not implemented yet.
1946 if (extractOp.hasDynamicPosition())
1947 return Value();
1948
1949 int64_t destinationRank =
1950 llvm::isa<VectorType>(extractOp.getType())
1951 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1952 : 0;
1953 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1954 if (!insertOp)
1955 return Value();
1956
1957 // 0-D vectors not supported.
1958 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1959 if (hasZeroDimVectors(insertOp))
1960 return Value();
1961
1962 while (insertOp) {
1963 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1964 insertOp.getSourceVectorType().getRank();
1965 if (destinationRank > insertOp.getSourceVectorType().getRank())
1966 return Value();
1967 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1968 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1969
1970 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1971 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1972 }))
1973 return Value();
1974 bool disjoint = false;
1975 SmallVector<int64_t, 4> offsetDiffs;
1976 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1977 int64_t start = insertOffsets[dim];
1978 int64_t size =
1979 (dim < insertRankDiff)
1980 ? 1
1981 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1982 int64_t end = start + size;
1983 int64_t offset = extractOffsets[dim];
1984 // Check if the start of the extract offset is in the interval inserted.
1985 if (start <= offset && offset < end) {
1986 if (dim >= insertRankDiff)
1987 offsetDiffs.push_back(offset - start);
1988 continue;
1989 }
1990 disjoint = true;
1991 break;
1992 }
1993 // The extract element chunk overlap with the vector inserted.
1994 if (!disjoint) {
1995 // If any of the inner dimensions are only partially inserted we have a
1996 // partial overlap.
1997 int64_t srcRankDiff =
1998 insertOp.getSourceVectorType().getRank() - destinationRank;
1999 for (int64_t i = 0; i < destinationRank; i++) {
2000 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
2001 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
2002 insertRankDiff))
2003 return Value();
2004 }
2005 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2006 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2007 OpBuilder b(extractOp.getContext());
2008 extractOp.setStaticPosition(offsetDiffs);
2009 return extractOp.getResult();
2010 }
2011 // If the chunk extracted is disjoint from the chunk inserted, keep
2012 // looking in the insert chain.
2013 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2014 }
2015 return Value();
2016}
2017
2018/// Try to fold the extraction of a scalar from a vector defined by
2019/// vector.from_elements. E.g.:
2020///
2021/// %0 = vector.from_elements %a, %b : vector<2xf32>
2022/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2023/// ==> fold to %a
2024static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2025 // Dynamic extractions cannot be folded.
2026 if (extractOp.hasDynamicPosition())
2027 return {};
2028
2029 // Look for extract(from_elements).
2030 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2031 if (!fromElementsOp)
2032 return {};
2033
2034 // Scalable vectors are not supported.
2035 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2036 if (vecType.isScalable())
2037 return {};
2038
2039 // Only extractions of scalars are supported.
2040 int64_t rank = vecType.getRank();
2041 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2042 if (extractOp.getType() != vecType.getElementType())
2043 return {};
2044 assert(static_cast<int64_t>(indices.size()) == rank &&
2045 "unexpected number of indices");
2046
2047 // Compute flattened/linearized index and fold to operand.
2048 int flatIndex = 0;
2049 int stride = 1;
2050 for (int i = rank - 1; i >= 0; --i) {
2051 flatIndex += indices[i] * stride;
2052 stride *= vecType.getDimSize(i);
2053 }
2054 return fromElementsOp.getElements()[flatIndex];
2055}
2056
2057/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2058/// then fold it.
2059template <typename OpType, typename AdaptorType>
2060static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2061 SmallVectorImpl<Value> &operands) {
2062 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2063 OperandRange dynamicPosition = op.getDynamicPosition();
2064 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2066 if constexpr (std::is_same_v<OpType, ExtractOp>)
2067 vectorShape = op.getSourceVectorType().getShape();
2068 else
2069 vectorShape = op.getDestVectorType().getShape();
2070
2071 // If the dynamic operands is empty, it is returned directly.
2072 if (!dynamicPosition.size())
2073 return {};
2074
2075 // `index` is used to iterate over the `dynamicPosition`.
2076 unsigned index = 0;
2077
2078 // `opChange` is a flag. If it is true, it means to update `op` in place.
2079 bool opChange = false;
2080 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2081 if (ShapedType::isStatic(staticPosition[i]))
2082 continue;
2083 Attribute positionAttr = dynamicPositionAttr[index];
2084 Value position = dynamicPosition[index++];
2085 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2086 int64_t value = attr.getInt();
2087 // Do not fold if the value is out of bounds (-1 signifies a poison
2088 // value rather than OOB index).
2089 if (value >= -1 && value < vectorShape[i]) {
2090 staticPosition[i] = attr.getInt();
2091 opChange = true;
2092 continue;
2093 }
2094 }
2095 operands.push_back(position);
2096 }
2097
2098 if (opChange) {
2099 op.setStaticPosition(staticPosition);
2100 op.getOperation()->setOperands(operands);
2101 // Return the original result to indicate an in-place folding happened.
2102 return op.getResult();
2103 }
2104 return {};
2105}
2106
2107/// Fold an insert or extract operation into an poison value when a poison index
2108/// is found at any dimension of the static position.
2110 ArrayRef<int64_t> staticPos,
2111 int64_t poisonVal) {
2112 if (!is_contained(staticPos, poisonVal))
2113 return {};
2114
2115 return ub::PoisonAttr::get(context);
2116}
2117
2118/// Fold a vector extract from is a poison source.
2120 if (matchPattern(srcAttr, ub::m_Poison()))
2121 return srcAttr;
2122
2123 return {};
2124}
2125
2126/// Fold a vector extract extracting from a DenseElementsAttr.
2128 Attribute srcAttr) {
2129 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2130 if (!denseAttr) {
2131 return {};
2132 }
2133
2134 if (denseAttr.isSplat()) {
2135 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2136 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2137 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2138 return newAttr;
2139 }
2140
2141 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2142 if (vecTy.isScalable())
2143 return {};
2144
2145 if (extractOp.hasDynamicPosition()) {
2146 return {};
2147 }
2148
2149 // Materializing subsets of a large constant array can generally lead to
2150 // explosion in IR size because of different combination of subsets that
2151 // can exist. However, vector.extract is a restricted form of subset
2152 // extract where you can only extract non-overlapping (or the same) subset for
2153 // a given rank of the subset. Because of this property, the IR size can only
2154 // increase at most by `rank * size(array)` from a single constant array being
2155 // extracted by multiple extracts.
2156
2157 // Calculate the linearized position of the continuous chunk of elements to
2158 // extract.
2159 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2160 copy(extractOp.getStaticPosition(), completePositions.begin());
2161 int64_t startPos =
2162 linearize(completePositions, computeStrides(vecTy.getShape()));
2163 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2164
2165 TypedAttr newAttr;
2166 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2167 SmallVector<Attribute> elementValues(
2168 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2169 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2170 } else {
2171 newAttr = *denseValuesBegin;
2172 }
2173
2174 return newAttr;
2175}
2176
2177OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2178 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2179 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2180 // mismatch).
2181 if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2182 return getSource();
2183 if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2184 return res;
2185 // Fold `arith.constant` indices into the `vector.extract` operation.
2186 // Do not stop here as this fold may enable subsequent folds that require
2187 // constant indices.
2188 SmallVector<Value> operands = {getSource()};
2189 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2190
2191 if (auto res = foldPoisonIndexInsertExtractOp(
2192 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2193 return res;
2194 if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2195 return res;
2196 if (succeeded(foldExtractOpFromExtractChain(*this)))
2197 return getResult();
2198 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2199 return res;
2200 if (auto res = foldExtractFromBroadcast(*this))
2201 return res;
2202 if (auto res = foldExtractFromShuffle(*this))
2203 return res;
2204 if (auto res = foldExtractFromShapeCast(*this))
2205 return res;
2206 if (auto val = foldExtractFromExtractStrided(*this))
2207 return val;
2208 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2209 return val;
2210 if (auto val = foldScalarExtractFromFromElements(*this))
2211 return val;
2212
2213 return inplaceFolded;
2214}
2215
2216namespace {
2217
2218// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2219class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2220public:
2221 using Base::Base;
2222
2223 LogicalResult matchAndRewrite(ExtractOp extractOp,
2224 PatternRewriter &rewriter) const override {
2225
2226 Operation *defOp = extractOp.getSource().getDefiningOp();
2227 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2228 if (!defOp || !isBroadcastLike(defOp) || !outType)
2229 return failure();
2230
2231 Value source = defOp->getOperand(0);
2232 if (isBroadcastableTo(source.getType(), outType) !=
2233 BroadcastableToResult::Success)
2234 return failure();
2235
2236 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2237 return success();
2238 }
2239};
2240
2241// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
2242class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2243public:
2244 using Base::Base;
2245
2246 LogicalResult matchAndRewrite(ExtractOp extractOp,
2247 PatternRewriter &rewriter) const override {
2248 auto createMaskOp =
2249 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2250 if (!createMaskOp)
2251 return failure();
2252
2253 VectorType extractedMaskType =
2254 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2255
2256 if (!extractedMaskType)
2257 return failure();
2258
2259 auto maskOperands = createMaskOp.getOperands();
2260 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2261 VectorType maskType = createMaskOp.getVectorType();
2262
2263 bool containsUnknownDims = false;
2264 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2265
2266 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2267 dimIdx++) {
2268 int64_t pos = extractOpPos[dimIdx];
2269 Value operand = maskOperands[dimIdx];
2270 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2271 if (!constantOp) {
2272 // Bounds of this dim unknown.
2273 containsUnknownDims = true;
2274 continue;
2275 }
2276
2277 int64_t createMaskBound =
2278 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2279
2280 if (pos != ShapedType::kDynamic) {
2281 // If any position is outside the range from the `create_mask`, then the
2282 // extracted mask will be all-false.
2283 allFalse |= pos >= createMaskBound;
2284 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2285 // This dim is not all-true and since this is a dynamic index we don't
2286 // know if the extraction is within the true or false region.
2287 // Note: Zero dims have already handled via getMaskFormat().
2288 containsUnknownDims = true;
2289 }
2290 }
2291
2292 if (allFalse) {
2293 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2294 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2295 } else if (!containsUnknownDims) {
2296 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2297 extractOp, extractedMaskType,
2298 maskOperands.drop_front(extractOpPos.size()));
2299 } else {
2300 return failure();
2301 }
2302 return success();
2303 }
2304};
2305
2306// Pattern to rewrite a ExtractOp(ConstantMask) -> ConstantMask.
2307class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
2308public:
2309 using Base::Base;
2310
2311 LogicalResult matchAndRewrite(ExtractOp extractOp,
2312 PatternRewriter &rewriter) const override {
2313 auto constantMaskOp =
2314 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2315 if (!constantMaskOp)
2316 return failure();
2317
2318 Type resultType = extractOp.getResult().getType();
2319 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2320
2321 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2322 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2323
2324 VectorType maskType = constantMaskOp.getVectorType();
2325
2326 // Check if any extracted position is outside the mask bounds.
2327 for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2328 int64_t pos = extractOpPos[dimIdx];
2329 if (pos == ShapedType::kDynamic) {
2330 // If the dim is all-true, a dynamic index is fine — any position
2331 // is within the masked region.
2332 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2333 continue;
2334 // Otherwise we don't know if the position is inside or outside of
2335 // the masked area, so bail out.
2336 return failure();
2337 }
2338
2339 // If the position is statically outside of the masked area, the result
2340 // will be all-false.
2341 if (pos >= maskDimSizes[dimIdx]) {
2342 if (extractedMaskType) {
2343 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2344 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2345 } else {
2346 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2347 extractOp, rewriter.getIntegerAttr(resultType, false));
2348 }
2349 return success();
2350 }
2351 }
2352
2353 // All positions are within the mask bounds.
2354 if (extractedMaskType) {
2355 // Vector result: the result is a constant_mask with the remaining
2356 // dimensions.
2357 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
2358 extractOp, extractedMaskType,
2359 maskDimSizes.drop_front(extractOpPos.size()));
2360 } else {
2361 // Scalar result: all positions are within the masked region, so the
2362 // result is true.
2363 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2364 extractOp, rewriter.getIntegerAttr(resultType, true));
2365 }
2366 return success();
2367 }
2368};
2369
2370// Folds extract(shape_cast(..)) into shape_cast when the total element count
2371// does not change.
2372LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2373 PatternRewriter &rewriter) {
2374 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2375 if (!castOp)
2376 return failure();
2377
2378 VectorType sourceType = castOp.getSourceVectorType();
2379 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2380 if (!targetType)
2381 return failure();
2382
2383 if (sourceType.getNumElements() != targetType.getNumElements())
2384 return failure();
2385
2386 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2387 castOp.getSource());
2388 return success();
2389}
2390
2391/// Try to canonicalize the extraction of a subvector from a vector defined by
2392/// vector.from_elements. E.g.:
2393///
2394/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2395/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2396/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2397LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2398 PatternRewriter &rewriter) {
2399 // Dynamic positions are not supported.
2400 if (extractOp.hasDynamicPosition())
2401 return failure();
2402
2403 // Scalar extracts are handled by the folder.
2404 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2405 if (!resultType)
2406 return failure();
2407
2408 // Look for extracts from a from_elements op.
2409 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2410 if (!fromElementsOp)
2411 return failure();
2412 VectorType inputType = fromElementsOp.getType();
2413
2414 // Scalable vectors are not supported.
2415 if (resultType.isScalable() || inputType.isScalable())
2416 return failure();
2417
2418 // Compute the position of first extracted element and flatten/linearize the
2419 // position.
2420 SmallVector<int64_t> firstElementPos =
2421 llvm::to_vector(extractOp.getStaticPosition());
2422 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2423 int flatIndex = 0;
2424 int stride = 1;
2425 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2426 flatIndex += firstElementPos[i] * stride;
2427 stride *= inputType.getDimSize(i);
2428 }
2429
2430 // Replace the op with a smaller from_elements op.
2431 rewriter.replaceOpWithNewOp<FromElementsOp>(
2432 extractOp, resultType,
2433 fromElementsOp.getElements().slice(flatIndex,
2434 resultType.getNumElements()));
2435 return success();
2436}
2437
2438/// Replace `vector.extract` with `vector.shape_cast`.
2439///
2440/// BEFORE:
2441/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2442/// AFTER:
2443/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2444///
2445/// The canonical form of vector operations that reshape vectors is shape_cast.
2446struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2447 using Base::Base;
2448 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2449 PatternRewriter &rewriter) const override {
2450 VectorType sourceType = extractOp.getSourceVectorType();
2451 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2452 if (!outType)
2453 return failure();
2454
2455 if (sourceType.getNumElements() != outType.getNumElements())
2456 return rewriter.notifyMatchFailure(
2457 extractOp, "extract to vector with fewer elements");
2458
2459 // Negative values in `position` means that the extacted value is poison.
2460 // There is a vector.extract folder for this.
2461 if (llvm::any_of(extractOp.getMixedPosition(),
2462 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2463 return rewriter.notifyMatchFailure(extractOp,
2464 "leaving for extract poison folder");
2465
2466 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2467 extractOp.getSource());
2468
2469 return success();
2470 }
2471};
2472
2473/// Folds vector.extract from vector.insert when the extract position is a
2474/// prefix of the insert position and the remaining (un-indexed) dimensions
2475/// of the extracted sub-vector are all size 1. In that case the extracted
2476/// value is fully determined by the inserted value.
2477///
2478/// Examples:
2479/// %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
2480/// %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
2481/// folds to:
2482/// %ext = vector.broadcast %s : f32 to vector<1xf32>
2483///
2484/// %ins = vector.insert %s, %v [0, 0] : vector<1xf32> into vector<16x1x1xf32>
2485// %ext = vector.extract %ins [0] : vector<1x1xf32> from vector<16x1x1xf32>
2486/// folds to:
2487/// %ext = vector.shape_cast %arg0 : vector<1xf32> to vector<1x1xf32>
2488struct FoldExtractFromInsertUnitDim final
2489 : OpRewritePattern<vector::ExtractOp> {
2490 using Base::Base;
2491
2492 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2493 PatternRewriter &rewriter) const override {
2494 if (extractOp.hasDynamicPosition())
2495 return failure();
2496
2497 auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
2498 if (!insertOp || insertOp.hasDynamicPosition())
2499 return failure();
2500
2501 ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
2502 ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
2503
2504 // The extract position must be a strict prefix of the insert position.
2505 if (extractPos.size() >= insertPos.size() ||
2506 extractPos != insertPos.take_front(extractPos.size()))
2507 return failure();
2508
2509 // The remaining dimensions (those not indexed by the extract) must all
2510 // be size 1 in the source vector type. This guarantees that the inserted
2511 // value fully determines the extracted sub-vector.
2512 auto srcVecType = extractOp.getSourceVectorType();
2513 for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
2514 if (srcVecType.getDimSize(i) != 1)
2515 return failure();
2516
2517 Value inserted = insertOp.getValueToStore();
2518 Type extractedType = extractOp.getResult().getType();
2519 if (isa<VectorType>(inserted.getType())) {
2520 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractedType,
2521 inserted);
2522 } else {
2523 // The inserted value fully determines the extracted sub-vector; broadcast
2524 // it to the extracted type.
2525 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2526 extractOp, extractOp.getResult().getType(),
2527 insertOp.getValueToStore());
2528 }
2529 return success();
2530 }
2531};
2532
2533} // namespace
2534
2535void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2536 MLIRContext *context) {
2537 results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2538 ExtractOpFromConstantMask, ExtractToShapeCast,
2539 FoldExtractFromInsertUnitDim>(context);
2540 results.add(foldExtractFromShapeCastToShapeCast);
2541 results.add(foldExtractFromFromElements);
2542}
2543
2545 SmallVectorImpl<int64_t> &results) {
2546 for (auto attr : arrayAttr)
2547 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2548}
2549
2550//===----------------------------------------------------------------------===//
2551// FmaOp
2552//===----------------------------------------------------------------------===//
2553
2554std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2555 return llvm::to_vector<4>(getVectorType().getShape());
2556}
2557
2558//===----------------------------------------------------------------------===//
2559// ToElementsOp
2560//===----------------------------------------------------------------------===//
2561
2562/// Returns true if all the `operands` are defined by `defOp`.
2563/// Otherwise, returns false.
2564static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2565 if (operands.empty())
2566 return false;
2567
2568 return llvm::all_of(operands, [&](Value operand) {
2569 Operation *currentDef = operand.getDefiningOp();
2570 return currentDef == defOp;
2571 });
2572}
2573
2574/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2575/// (%e0, %e1, ...). For example:
2576///
2577/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2578/// %1:3 = vector.to_elements %0 : vector<3xf32>
2579/// user_op %1#0, %1#1, %1#2
2580///
2581/// becomes:
2582///
2583/// user_op %a, %b, %c
2584///
2585static LogicalResult
2586foldToElementsFromElements(ToElementsOp toElementsOp,
2588 auto fromElementsOp =
2589 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2590 if (!fromElementsOp)
2591 return failure();
2592
2593 llvm::append_range(results, fromElementsOp.getElements());
2594 return success();
2595}
2596
2597/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2598///
2599/// Example:
2600/// %b = vector.broadcast %x : i32 to vector<3xf32>
2601/// %e:3 = vector.to_elements %b : vector<3xf32>
2602/// user_op %e#0, %e#1, %e#2
2603/// becomes:
2604/// user_op %x, %x, %x
2605///
2606/// The vector source case is handled by a canonicalization pattern.
2607static LogicalResult
2608foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2610 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2611 if (!bcastOp)
2612 return failure();
2613 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2614 if (isa<VectorType>(bcastOp.getSource().getType()))
2615 return failure();
2616
2617 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2618
2619 Value scalar = bcastOp.getSource();
2620 results.assign(resultVecType.getNumElements(), scalar);
2621 return success();
2622}
2623
2624LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2625 SmallVectorImpl<OpFoldResult> &results) {
2626 if (succeeded(foldToElementsFromElements(*this, results)))
2627 return success();
2628
2629 // Y = ToElements(ShapeCast(X)) -> Y = ToElements(X)
2630 if (auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2631 setOperand(shapeCast.getSource());
2632 return success();
2633 }
2634
2635 return foldToElementsOfBroadcast(*this, results);
2636}
2637
2638LogicalResult
2639ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2640 ToElementsOp::Adaptor adaptor,
2641 SmallVectorImpl<Type> &inferredReturnTypes) {
2642 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2643 Type elType = vecType.getElementType();
2644 inferredReturnTypes.append(vecType.getNumElements(), elType);
2645 return success();
2646}
2647
2648/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2649/// vector.
2650/// - Build `vector.to_elements %v` and remap each destination element to the
2651/// corresponding source element using broadcast rules (match or 1 →
2652/// replicate).
2653///
2654/// Example:
2655/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2656/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2657/// becomes:
2658/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2659/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2660/// // %src_elems#1, %src_elems#0, %src_elems#1
2661struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2662 using Base::Base;
2663
2664 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2665 PatternRewriter &rewriter) const override {
2666 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2667 if (!bcastOp)
2668 return failure();
2669
2670 // Only handle broadcasts from a vector source here.
2671 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2672 if (!srcType)
2673 return failure();
2674
2675 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2676
2677 ArrayRef<int64_t> dstShape = dstType.getShape();
2678 ArrayRef<int64_t> srcShape = srcType.getShape();
2679
2680 int64_t dstRank = dstShape.size();
2681 int64_t srcRank = srcShape.size();
2682
2683 // Create elements for the broadcast source vector.
2684 auto srcElems = vector::ToElementsOp::create(
2685 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2686
2687 int64_t dstCount = llvm::product_of(dstShape);
2688
2689 SmallVector<Value> replacements;
2690 replacements.reserve(dstCount);
2691
2692 // For each element of the destination, determine which element of the
2693 // source should be used. We walk all destination positions using a single
2694 // counter, decode it into per-dimension indices, then build the matching
2695 // source position: use the same index where sizes match, and use 0 where
2696 // the source size is 1 (replication). This mapping is needed so we can
2697 // replace each result of to_elements with the corresponding element from
2698 // the broadcast source.
2699 // Inner-dimension stretch example:
2700 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2701 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2702 // becomes:
2703 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2704 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2705 // // %src_elems#1, %src_elems#0, %src_elems#1,
2706 // // %src_elems#2, %src_elems#3, %src_elems#2,
2707 // // %src_elems#3, %src_elems#2, %src_elems#3
2708
2709 // Row-major strides for the destination shape.
2710 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2711 // Row-major strides for the source shape.
2712 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2713 SmallVector<int64_t> dstIdx(dstRank);
2714 SmallVector<int64_t> srcIdx(srcRank);
2715 for (int64_t lin = 0; lin < dstCount; ++lin) {
2716 // Convert linear destination index to per-dimension indices.
2717 dstIdx = delinearize(lin, dstStrides);
2718 for (int64_t k = 0; k < srcRank; ++k)
2719 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2720 // Convert per-dimension source indices back to a linear index.
2721 int64_t srcLin = linearize(srcIdx, srcStrides);
2722 replacements.push_back(srcElems.getResult(srcLin));
2723 }
2724
2725 rewriter.replaceOp(toElementsOp, replacements);
2726 return success();
2727 }
2728};
2729
2730void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2731 MLIRContext *context) {
2732 results.add<ToElementsOfBroadcast>(context);
2733}
2734
2735//===----------------------------------------------------------------------===//
2736// FromElementsOp
2737//===----------------------------------------------------------------------===//
2738
2739/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2740///
2741/// Case #1: Input and output vectors are the same.
2742///
2743/// %0:3 = vector.to_elements %a : vector<3xf32>
2744/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2745/// user_op %1
2746///
2747/// becomes:
2748///
2749/// user_op %a
2750///
2751static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2752 OperandRange fromElemsOperands = fromElementsOp.getElements();
2753 if (fromElemsOperands.empty())
2754 return {};
2755
2756 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2757 if (!toElementsOp)
2758 return {};
2759
2760 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2761 return {};
2762
2763 // Case #1: Input and output vectors are the same. Forward the input vector.
2764 Value toElementsInput = toElementsOp.getSource();
2765 if (fromElementsOp.getType() == toElementsInput.getType() &&
2766 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2767 return toElementsInput;
2768 }
2769
2770 // TODO: Support cases with different input and output shapes and different
2771 // number of elements.
2772
2773 return {};
2774}
2775
2776/// Fold vector.from_elements to a constant when all operands are constants.
2777/// Example:
2778/// %c1 = arith.constant 1 : i32
2779/// %c2 = arith.constant 2 : i32
2780/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2781/// =>
2782/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2783///
2784static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2785 ArrayRef<Attribute> elements) {
2786 // Check for null or poison attributes before any processing.
2787 if (llvm::any_of(elements, [](Attribute attr) {
2788 return !attr || matchPattern(attr, ub::m_Poison());
2789 }))
2790 return {};
2791
2792 // DenseElementsAttr only supports int/index/float/complex types.
2793 auto destVecType = fromElementsOp.getDest().getType();
2794 auto destEltType = destVecType.getElementType();
2795 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2796 return {};
2797
2798 // Constant attributes might have a different type than the return type.
2799 // Convert them before creating the dense elements attribute.
2800 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2801 return convertNumericAttr(attr, destEltType);
2802 });
2803
2804 return DenseElementsAttr::get(destVecType, convertedElements);
2805}
2806
2807OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2808 if (auto res = foldFromElementsToElements(*this))
2809 return res;
2810 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2811 return res;
2812
2813 return {};
2814}
2815
2816/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2817/// same. Example:
2818/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2819/// =>
2820/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2821static LogicalResult
2822rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2823 PatternRewriter &rewriter) {
2824 if (!llvm::all_equal(fromElementsOp.getElements()))
2825 return failure();
2826 rewriter.replaceOpWithNewOp<BroadcastOp>(
2827 fromElementsOp, fromElementsOp.getType(),
2828 fromElementsOp.getElements().front());
2829 return success();
2830}
2831
2832/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2833/// on a single extract. Example:
2834/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2835/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2836/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2837///
2838/// becomes
2839/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2840/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2841///
2842/// The requirements for this to be valid are
2843///
2844/// i) The elements are extracted from the same vector (%source).
2845///
2846/// ii) The elements form a suffix of %source. Specifically, the number
2847/// of elements is the same as the product of the last N dimension sizes
2848/// of %source, for some N.
2849///
2850/// iii) The elements are extracted contiguously in ascending order.
2851
2852class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2853
2854 using Base::Base;
2855
2856 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2857 PatternRewriter &rewriter) const override {
2858
2859 // Handled by `rewriteFromElementsAsBroadcast`.
2860 if (fromElements.getType().getNumElements() == 1)
2861 return failure();
2862
2863 // The common source that all elements are extracted from, if one exists.
2865 // The position of the combined extract operation, if one is created.
2866 ArrayRef<int64_t> combinedPosition;
2867 // The expected index of extraction of the current element in the loop, if
2868 // elements are extracted contiguously in ascending order.
2869 SmallVector<int64_t> expectedPosition;
2870
2871 for (auto [insertIndex, element] :
2872 llvm::enumerate(fromElements.getElements())) {
2873
2874 // Check that the element is from a vector.extract operation.
2875 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2876 if (!extractOp) {
2877 return rewriter.notifyMatchFailure(fromElements,
2878 "element not from vector.extract");
2879 }
2880
2881 // Check condition (i) by checking that all elements have the same source
2882 // as the first element.
2883 if (insertIndex == 0) {
2884 source = extractOp.getSource();
2885 } else if (extractOp.getSource() != source) {
2886 return rewriter.notifyMatchFailure(fromElements,
2887 "element from different vector");
2888 }
2889
2890 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2891 int64_t rank = position.size();
2892 assert(rank == source.getType().getRank() &&
2893 "scalar extract must have full rank position");
2894
2895 // Check condition (ii) by checking that the position that the first
2896 // element is extracted from has sufficient trailing 0s. For example, in
2897 //
2898 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2899 // [...]
2900 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2901 //
2902 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2903 // elements, which is the number of elements of %n, so this is valid.
2904 if (insertIndex == 0) {
2905 const int64_t numElms = fromElements.getType().getNumElements();
2906 int64_t numSuffixElms = 1;
2907 int64_t index = rank;
2908 while (index > 0 && position[index - 1] == 0 &&
2909 numSuffixElms < numElms) {
2910 numSuffixElms *= source.getType().getDimSize(index - 1);
2911 --index;
2912 }
2913 if (numSuffixElms != numElms) {
2914 return rewriter.notifyMatchFailure(
2915 fromElements, "elements do not form a suffix of source");
2916 }
2917 expectedPosition = llvm::to_vector(position);
2918 combinedPosition = position.drop_back(rank - index);
2919 }
2920
2921 // Check condition (iii).
2922 else if (expectedPosition != position) {
2923 return rewriter.notifyMatchFailure(
2924 fromElements, "elements not in ascending order (static order)");
2925 }
2926 increment(expectedPosition, source.getType().getShape());
2927 }
2928
2929 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2930 fromElements.getLoc(), source, combinedPosition);
2931
2932 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2933 fromElements, fromElements.getType(), extracted);
2934
2935 return success();
2936 }
2937
2938 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2939 static void increment(MutableArrayRef<int64_t> indices,
2941 for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2942 indices[dim] += 1;
2943 if (indices[dim] < shape[dim])
2944 break;
2945 indices[dim] = 0;
2946 }
2947 }
2948};
2949
2950void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2951 MLIRContext *context) {
2953 results.add<FromElementsToShapeCast>(context);
2954}
2955
2956//===----------------------------------------------------------------------===//
2957// BroadcastOp
2958//===----------------------------------------------------------------------===//
2959
2960void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2961 SetIntRangeFn setResultRanges) {
2962 setResultRanges(getResult(), argRanges.front());
2963}
2964
2965std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2966 return llvm::to_vector<4>(getResultVectorType().getShape());
2967}
2968
2969/// Return the dimensions of the result vector that were formerly ones in the
2970/// source tensor and thus correspond to "dim-1" broadcasting.
2971static llvm::SetVector<int64_t>
2973 ArrayRef<int64_t> dstShape) {
2974 int64_t rankDiff = dstShape.size() - srcShape.size();
2975 int64_t dstDim = rankDiff;
2977 for (auto [s1, s2] :
2978 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2979 if (s1 != s2) {
2980 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2981 res.insert(dstDim);
2982 }
2983 ++dstDim;
2984 }
2985 return res;
2986}
2987
2988llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2989 // Scalar broadcast is without any unit dim broadcast.
2990 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2991 if (!srcVectorType)
2992 return {};
2993 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2994 getResultVectorType().getShape());
2995}
2996
2997/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2998/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2999/// This requires (and asserts) that the broadcast is free of "dim-1"
3000/// broadcasting.
3001/// Since vector.broadcast only allows expanding leading dimensions, an extra
3002/// vector.transpose may be inserted to make the broadcast possible.
3003/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
3004/// the helper will assert. This means:
3005/// 1. `dstShape` must not be empty.
3006/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
3007/// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
3008// must match the `value` shape.
3009Value BroadcastOp::createOrFoldBroadcastOp(
3010 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
3011 const llvm::SetVector<int64_t> &broadcastedDims) {
3012 assert(!dstShape.empty() && "unexpected empty dst shape");
3013
3014 // Well-formedness check.
3015 SmallVector<int64_t> checkShape;
3016 for (int i = 0, e = dstShape.size(); i < e; ++i) {
3017 if (broadcastedDims.contains(i))
3018 continue;
3019 checkShape.push_back(dstShape[i]);
3020 }
3021 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
3022 "ill-formed broadcastedDims contains values not confined to "
3023 "destVectorShape");
3024
3025 Location loc = value.getLoc();
3026 Type elementType = getElementTypeOrSelf(value.getType());
3027 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
3028 VectorType dstVectorType = VectorType::get(dstShape, elementType);
3029
3030 // Step 2. If scalar -> dstShape broadcast, just do it.
3031 if (!srcVectorType) {
3032 assert(checkShape.empty() &&
3033 "ill-formed createOrFoldBroadcastOp arguments");
3034 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
3035 }
3036
3037 assert(srcVectorType.getShape().equals(checkShape) &&
3038 "ill-formed createOrFoldBroadcastOp arguments");
3039
3040 // Step 3. Since vector.broadcast only allows creating leading dims,
3041 // vector -> dstShape broadcast may require a transpose.
3042 // Traverse the dims in order and construct:
3043 // 1. The leading entries of the broadcastShape that is guaranteed to be
3044 // achievable by a simple broadcast.
3045 // 2. The induced permutation for the subsequent vector.transpose that will
3046 // bring us from `broadcastShape` back to he desired `dstShape`.
3047 // If the induced permutation is not the identity, create a vector.transpose.
3048 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
3049 broadcastShape.reserve(dstShape.size());
3050 // Consider the example:
3051 // srcShape = 2x4
3052 // dstShape = 1x2x3x4x5
3053 // broadcastedDims = [0, 2, 4]
3054 //
3055 // We want to build:
3056 // broadcastShape = 1x3x5x2x4
3057 // permutation = [0, 2, 4, 1, 3]
3058 // ---V--- -----V-----
3059 // leading broadcast part src shape part
3060 //
3061 // Note that the trailing dims of broadcastShape are exactly the srcShape
3062 // by construction.
3063 // nextSrcShapeDim is used to keep track of where in the permutation the
3064 // "src shape part" occurs.
3065 int64_t nextSrcShapeDim = broadcastedDims.size();
3066 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
3067 if (broadcastedDims.contains(i)) {
3068 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
3069 // bring it to the head of the broadcastShape.
3070 // It will need to be permuted back from `broadcastShape.size() - 1` into
3071 // position `i`.
3072 broadcastShape.push_back(dstShape[i]);
3073 permutation[i] = broadcastShape.size() - 1;
3074 } else {
3075 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
3076 // shape and needs to be permuted into position `i`.
3077 // Don't touch `broadcastShape` here, the whole srcShape will be
3078 // appended after.
3079 permutation[i] = nextSrcShapeDim++;
3080 }
3081 }
3082 // 3.c. Append the srcShape.
3083 llvm::append_range(broadcastShape, srcVectorType.getShape());
3084
3085 // Ensure there are no "dim-1" broadcasts.
3086 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
3087 .empty() &&
3088 "unexpected \"dim-1\" broadcast");
3089
3090 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3091 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
3092 vector::BroadcastableToResult::Success &&
3093 "must be broadcastable");
3094 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3095 // Step 4. If we find any dimension that indeed needs to be permuted,
3096 // immediately return a new vector.transpose.
3097 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3098 if (permutation[i] != i)
3099 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3100 // Otherwise return res.
3101 return res;
3102}
3103
3105 Type srcType, VectorType dstVectorType,
3106 std::pair<VectorDim, VectorDim> *mismatchingDims) {
3107 // Broadcast scalar to vector of the same element type.
3108 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3109 srcType == getElementTypeOrSelf(dstVectorType))
3111 // From now on, only vectors broadcast.
3112 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3113 if (!srcVectorType)
3115
3116 int64_t srcRank = srcVectorType.getRank();
3117 int64_t dstRank = dstVectorType.getRank();
3118 if (srcRank > dstRank)
3120 // Source has an exact match or singleton value for all trailing dimensions
3121 // (all leading dimensions are simply duplicated).
3122 int64_t lead = dstRank - srcRank;
3123 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3124 // Have mismatching dims (in the sense of vector.broadcast semantics) been
3125 // encountered?
3126 bool foundMismatchingDims = false;
3127
3128 // Check fixed-width dims.
3129 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3130 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3131 if (srcDim != 1 && srcDim != dstDim)
3132 foundMismatchingDims = true;
3133
3134 // Check scalable flags.
3135 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3136 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3137 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3138 // 1 -> [N] is fine, everything else should be rejected when mixing
3139 // fixed-width and scalable dims
3140 (srcDimScalableFlag != dstDimScalableFlag &&
3141 (srcDim != 1 || srcDimScalableFlag)))
3142 foundMismatchingDims = true;
3143
3144 if (foundMismatchingDims) {
3145 if (mismatchingDims != nullptr) {
3146 mismatchingDims->first.dim = srcDim;
3147 mismatchingDims->first.isScalable = srcDimScalableFlag;
3148
3149 mismatchingDims->second.dim = dstDim;
3150 mismatchingDims->second.isScalable = dstDimScalableFlag;
3151 }
3153 }
3154 }
3155
3157}
3158
3159LogicalResult BroadcastOp::verify() {
3160 std::pair<VectorDim, VectorDim> mismatchingDims;
3162 getSourceType(), getResultVectorType(), &mismatchingDims);
3164 return success();
3166 return emitOpError("source rank higher than destination rank");
3168 return emitOpError("dimension mismatch (")
3169 << (mismatchingDims.first.isScalable ? "[" : "")
3170 << mismatchingDims.first.dim
3171 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
3172 << (mismatchingDims.second.isScalable ? "[" : "")
3173 << mismatchingDims.second.dim
3174 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
3175 }
3177 return emitOpError("source type is not a vector");
3178 llvm_unreachable("unexpected vector.broadcast op error");
3179}
3180
3181// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3182// with broadcast's result type and shape_cast only adds or removes ones in the
3183// leading dimensions.
3184static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3185 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3186 if (!srcShapeCast)
3187 return failure();
3188
3189 VectorType srcType = srcShapeCast.getSourceVectorType();
3190 VectorType destType = broadcastOp.getResultVectorType();
3191 // Check type compatibility.
3192 if (vector::isBroadcastableTo(srcType, destType) !=
3194 return failure();
3195
3196 ArrayRef<int64_t> srcShape = srcType.getShape();
3197 ArrayRef<int64_t> shapecastShape =
3198 srcShapeCast.getResultVectorType().getShape();
3199 // Trailing dimensions should be the same if shape_cast only alters the
3200 // leading dimensions.
3201 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3202 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3203 shapecastShape.take_back(numTrailingDims)))
3204 return failure();
3205
3206 assert(all_of(srcShape.drop_back(numTrailingDims),
3207 [](int64_t E) { return E == 1; }) &&
3208 all_of(shapecastShape.drop_back(numTrailingDims),
3209 [](int64_t E) { return E == 1; }) &&
3210 "ill-formed shape_cast");
3211
3212 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3213 return success();
3214}
3215
3216OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3217 if (getSourceType() == getResultVectorType())
3218 return getSource();
3219 if (succeeded(foldBroadcastOfShapeCast(*this)))
3220 return getResult();
3221
3222 if (!adaptor.getSource())
3223 return {};
3224 auto vectorType = getResultVectorType();
3225 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3226 if (vectorType.getElementType() != attr.getType())
3227 return {};
3228 return DenseElementsAttr::get(vectorType, attr);
3229 }
3230 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3231 if (vectorType.getElementType() != attr.getType())
3232 return {};
3233 return DenseElementsAttr::get(vectorType, attr);
3234 }
3235 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3236 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3237 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
3238 return ub::PoisonAttr::get(getContext());
3239 return {};
3240}
3241
3242namespace {
3243
3244// Fold broadcast1(broadcast2(x)) into broadcast1(x).
3245struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3246 using Base::Base;
3247
3248 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3249 PatternRewriter &rewriter) const override {
3250 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3251 if (!srcBroadcast)
3252 return failure();
3253 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3254 broadcastOp.getResultVectorType(),
3255 srcBroadcast.getSource());
3256 return success();
3257 }
3258};
3259
3260/// Replace `vector.broadcast` with `vector.shape_cast`.
3261///
3262/// BEFORE:
3263/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3264/// AFTER:
3265/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3266///
3267/// The canonical form of vector operations that reshape vectors is shape_cast.
3268struct BroadcastToShapeCast final
3269 : public OpRewritePattern<vector::BroadcastOp> {
3270 using Base::Base;
3271 LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
3272 PatternRewriter &rewriter) const override {
3273
3274 auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
3275 if (!sourceType) {
3276 return rewriter.notifyMatchFailure(
3277 broadcast, "source is a scalar, shape_cast doesn't support scalar");
3278 }
3279
3280 VectorType outType = broadcast.getType();
3281 if (sourceType.getNumElements() != outType.getNumElements()) {
3282 return rewriter.notifyMatchFailure(
3283 broadcast, "broadcast to a greater number of elements");
3284 }
3285
3286 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
3287 broadcast.getSource());
3288 return success();
3289 }
3290};
3291} // namespace
3292
3293void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3294 MLIRContext *context) {
3295 results.add<BroadcastFolder, BroadcastToShapeCast>(context);
3296}
3297
3298//===----------------------------------------------------------------------===//
3299// ShuffleOp
3300//===----------------------------------------------------------------------===//
3301
3302LogicalResult ShuffleOp::verify() {
3303 VectorType resultType = getResultVectorType();
3304 VectorType v1Type = getV1VectorType();
3305 VectorType v2Type = getV2VectorType();
3306 // Verify ranks.
3307 int64_t resRank = resultType.getRank();
3308 int64_t v1Rank = v1Type.getRank();
3309 int64_t v2Rank = v2Type.getRank();
3310 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3311 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3312 if (!wellFormed0DCase && !wellFormedNDCase)
3313 return emitOpError("rank mismatch");
3314
3315 // Verify all but leading dimension sizes.
3316 for (int64_t r = 1; r < v1Rank; ++r) {
3317 int64_t resDim = resultType.getDimSize(r);
3318 int64_t v1Dim = v1Type.getDimSize(r);
3319 int64_t v2Dim = v2Type.getDimSize(r);
3320 if (resDim != v1Dim || v1Dim != v2Dim)
3321 return emitOpError("dimension mismatch");
3322 }
3323 // Verify mask length.
3324 ArrayRef<int64_t> mask = getMask();
3325 int64_t maskLength = mask.size();
3326 if (maskLength <= 0)
3327 return emitOpError("invalid mask length");
3328 if (maskLength != resultType.getDimSize(0))
3329 return emitOpError("mask length mismatch");
3330 // Verify all indices.
3331 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3332 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3333 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3334 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3335 return emitOpError("mask index #") << (idx + 1) << " out of range";
3336 }
3337 return success();
3338}
3339
3340LogicalResult
3341ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3342 ShuffleOp::Adaptor adaptor,
3343 SmallVectorImpl<Type> &inferredReturnTypes) {
3344 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3345 if (!v1Type) {
3346 return emitOptionalError(loc, "expected vector type");
3347 }
3348 auto v1Rank = v1Type.getRank();
3349 // Construct resulting type: leading dimension matches mask
3350 // length, all trailing dimensions match the operands.
3351 SmallVector<int64_t, 4> shape;
3352 shape.reserve(v1Rank);
3353 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3354 // In the 0-D case there is no trailing shape to append.
3355 if (v1Rank > 0)
3356 llvm::append_range(shape, v1Type.getShape().drop_front());
3357 inferredReturnTypes.push_back(
3358 VectorType::get(shape, v1Type.getElementType()));
3359 return success();
3360}
3361
3362template <typename T>
3363static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3364 T expected = begin;
3365 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3366 return value == expected++;
3367 });
3368}
3369
3370/// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3371/// Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
///
/// The mask is compared against the step sequence covering all of V1's leading
/// dimension, then against the step sequence covering all of V2's leading
/// dimension (offset by V1's leading size). Returns a null result if neither
/// matches.
3373 auto v1Type = op.getV1VectorType();
3374 auto v2Type = op.getV2VectorType();
3375 auto mask = op.getMask();
 // Mask == [0, 1, ..., dim0(V1) - 1]: the shuffle reproduces V1.
3376 if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3377 return op.getV1();
 // Mask == [dim0(V1), ..., dim0(V1) + dim0(V2) - 1]: the shuffle reproduces V2.
3378 if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3379 return op.getV2();
3380 return {};
3381}
3382
3383/// If a shuffle operand is poison, replace all mask indices that reference it
3384/// with kPoisonIndex. This is an in-place fold.
///
/// Returns a null result when neither operand matches `ub::m_Poison()` or when
/// no mask entry had to be rewritten; otherwise stores the rewritten mask on
/// `op` and returns the op's own result.
3386 bool isV1Poison = matchPattern(op.getV1(), ub::m_Poison());
3387 bool isV2Poison = matchPattern(op.getV2(), ub::m_Poison());
3388 if (!isV1Poison && !isV2Poison)
3389 return {};
3390
 // Mask entries below `v1Size` select from V1, the remaining ones from V2.
3391 int64_t v1Size = op.getV1VectorType().getDimSize(0);
3392 bool changed = false;
3393 SmallVector<int64_t> newMask = llvm::to_vector(op.getMask());
3394 for (int64_t &idx : newMask) {
 // Entries that are already poison need no rewriting.
3395 if (idx == ShuffleOp::kPoisonIndex)
3396 continue;
3397 if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
3398 idx = ShuffleOp::kPoisonIndex;
3399 changed = true;
3400 }
3401 }
3402
 // Only report a fold when the mask actually changed.
3403 if (!changed)
3404 return {};
3405
3406 op.setMask(newMask);
3407 return op.getResult();
3408}
3409
3410/// Fold shuffle poison, poison -> poison.
///
/// Returns a `ub::PoisonAttr` created in `context` when both `v1Attr` and
/// `v2Attr` match `ub::m_Poison()`, and a null result otherwise.
3412 Attribute v1Attr,
3413 Attribute v2Attr) {
3414 if (matchPattern(v1Attr, ub::m_Poison()) &&
3415 matchPattern(v2Attr, ub::m_Poison()))
3416 return ub::PoisonAttr::get(context);
3417 return {};
3418}
3419
3420/// Fold a shuffle of constant 1-D inputs by evaluating the mask.
///
/// Each input attribute must either match `ub::m_Poison()` or be a
/// `DenseElementsAttr`; anything else (or a V1 rank other than 1) yields a null
/// result. On success, returns a `DenseElementsAttr` of the op's result type
/// holding the selected elements.
3422 Attribute v2Attr) {
3423 auto v1Type = op.getV1VectorType();
3424 if (v1Type.getRank() != 1)
3425 return {};
3426
3427 bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
3428 bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
3429
3430 // Poison input attributes need special handling as they are not
3431 // DenseElementsAttr. If an index is poison, we select the first element of
3432 // the first non-poison input.
3433 SmallVector<Attribute> v1Elements, v2Elements;
3434 Attribute poisonElement;
 // V2 is processed first so that a non-poison V1 (below) overrides
 // `poisonElement` and therefore takes precedence.
3435 if (!isV2Poison) {
3436 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3437 if (!v2DenseAttr)
3438 return {};
3439 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3440 poisonElement = v2Elements[0];
3441 }
3442 if (!isV1Poison) {
3443 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3444 if (!v1DenseAttr)
3445 return {};
3446 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3447 poisonElement = v1Elements[0];
3448 }
 // NOTE(review): if both inputs are poison, `poisonElement` stays null here.
 // ShuffleOp::fold tries foldShufflePoisonInputs first, which presumably
 // keeps that case from reaching this function - confirm for other callers.
3449
3450 ArrayRef<int64_t> mask = op.getMask();
3451 SmallVector<Attribute> results;
3452 int64_t v1Size = v1Type.getDimSize(0);
3453 for (int64_t maskIdx : mask) {
3454 Attribute indexedElm;
3455 // TODO: Return a partial poison vector when supported by the UB dialect.
3456 if (maskIdx == ShuffleOp::kPoisonIndex) {
3457 indexedElm = poisonElement;
3458 } else {
 // Indices below `v1Size` address V1; the rest address V2 after
 // subtracting `v1Size`. Elements read from a poison input are replaced
 // by `poisonElement`.
3459 if (maskIdx < v1Size)
3460 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3461 else
3462 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3463 }
3464
3465 results.push_back(indexedElm);
3466 }
3467
3468 return DenseElementsAttr::get(op.getResultVectorType(), results);
3469}
3470
3471OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3472 auto v1Type = getV1VectorType();
3473
3474 assert(!v1Type.isScalable() && !getV2VectorType().isScalable() &&
3475 "Vector shuffle does not support scalable vectors");
3476
3477 // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
3478 // but must be a canonicalization into a vector.broadcast.
3479 if (v1Type.getRank() == 0)
3480 return {};
3481
3482 if (auto res = foldShuffleIdentityMask(*this))
3483 return res;
3484 if (auto res = foldShufflePoisonOperandToMask(*this))
3485 return res;
3486
3487 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3488 if (!v1Attr || !v2Attr)
3489 return {};
3490
3491 if (auto res = foldShufflePoisonInputs(getContext(), v1Attr, v2Attr))
3492 return res;
3493 if (auto res = foldShuffleConstantInputs(*this, v1Attr, v2Attr))
3494 return res;
3495
3496 return {};
3497}
3498
3499namespace {
3500
3501// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3502// to a broadcast.
3503struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3504 using Base::Base;
3505
3506 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3507 PatternRewriter &rewriter) const override {
3508 VectorType v1VectorType = shuffleOp.getV1VectorType();
3509 ArrayRef<int64_t> mask = shuffleOp.getMask();
3510 if (v1VectorType.getRank() > 0)
3511 return failure();
3512 if (mask.size() != 1)
3513 return failure();
3514 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3515 if (mask[0] == 0)
3516 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3517 shuffleOp.getV1());
3518 else
3519 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3520 shuffleOp.getV2());
3521 return success();
3522 }
3523};
3524
3525/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3526/// vector.broadcast with a scalar operand, return the scalar value that is
3527/// splatted. Otherwise return null.
3528///
3529/// Example:
3530///
3531/// scalar_source --> vector.broadcast --> value - return scalar_source
3532static Value getScalarSplatSource(Value value) {
3533 // Block argument:
3534 Operation *defOp = value.getDefiningOp();
3535 if (!defOp)
3536 return {};
3537
3538 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3539
3540 // Not broadcast (and not splat):
3541 if (!broadcast)
3542 return {};
3543
3544 // Broadcast of a vector:
3545 if (isa<VectorType>(broadcast.getSourceType()))
3546 return {};
3547
3548 // Broadcast of a scalar:
3549 return broadcast.getSource();
3550}
3551
3552/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
3553class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3554public:
3555 using Base::Base;
3556
3557 LogicalResult matchAndRewrite(ShuffleOp op,
3558 PatternRewriter &rewriter) const override {
3559 Value splat = getScalarSplatSource(op.getV1());
3560 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3561 return failure();
3562
3563 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3564 return success();
3565 }
3566};
3567
3568/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3569/// vector.interleave.
3570class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3571public:
3572 using Base::Base;
3573
3574 LogicalResult matchAndRewrite(ShuffleOp op,
3575 PatternRewriter &rewriter) const override {
3576 VectorType resultType = op.getResultVectorType();
3577 if (resultType.isScalable())
3578 return rewriter.notifyMatchFailure(
3579 op, "ShuffleOp can't represent a scalable interleave");
3580
3581 if (resultType.getRank() != 1)
3582 return rewriter.notifyMatchFailure(
3583 op, "ShuffleOp can't represent an n-D interleave");
3584
3585 VectorType sourceType = op.getV1VectorType();
3586 if (sourceType != op.getV2VectorType() ||
3587 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3588 return rewriter.notifyMatchFailure(
3589 op, "ShuffleOp types don't match an interleave");
3590 }
3591
3592 ArrayRef<int64_t> shuffleMask = op.getMask();
3593 int64_t resultVectorSize = resultType.getNumElements();
3594 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3595 int64_t maskValueA = shuffleMask[i * 2];
3596 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3597 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3598 return rewriter.notifyMatchFailure(op,
3599 "ShuffleOp mask not interleaving");
3600 }
3601
3602 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3603 return success();
3604 }
3605};
3606
3607/// Pattern to replace usused shuffle operands / results with poison.
3608///
3609/// Example Input:
3610/// %r = vector.shuffle %v1, %v2 [2, 3, 3, 3] : vector<2xi32>, vector<2xi32>
3611///
3612/// Example Output:
3613/// %0 = ub.poison : vector<2xi32>
3614/// %r = vector.shuffle %0, %v2 [2, 3, 3, 3] : vector<2xi32>, vector<2xi32>
3615class FoldUnusedShuffleOperand final : public OpRewritePattern<ShuffleOp> {
3616public:
3617 using Base::Base;
3618
3619 LogicalResult matchAndRewrite(ShuffleOp op,
3620 PatternRewriter &rewriter) const override {
3621 // Replace with poison if all mask elements are poison.
3622 if (llvm::all_of(op.getMask(), [](int64_t mask) {
3623 return mask == ShuffleOp::kPoisonIndex;
3624 })) {
3625 rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType());
3626 return success();
3627 }
3628
3629 // Helper function to replace an operand with poison.
3630 auto replaceOperandWithPoison = [&](OpOperand &operand) {
3631 // Do not replace if the operand is already poison.
3632 if (!matchPattern(operand.get(), ub::m_Poison())) {
3633 Value poison = ub::PoisonOp::create(rewriter, op.getLoc(),
3634 operand.get().getType());
3635 rewriter.modifyOpInPlace(op, [&]() { operand.set(poison); });
3636 return success();
3637 }
3638 return failure();
3639 };
3640
3641 // Replace V1 with poison if it is not used.
3642 int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
3643 ? op.getV1VectorType().getDimSize(0)
3644 : 1;
3645 bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3646 return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
3647 });
3648 if (!isV1Used && succeeded(replaceOperandWithPoison(op.getV1Mutable())))
3649 return success();
3650
3651 // Replace V2 with poison if it is not used.
3652 bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3653 return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
3654 });
3655 if (!isV2Used && succeeded(replaceOperandWithPoison(op.getV2Mutable())))
3656 return success();
3657
3658 return failure();
3659 }
3660};
3661} // namespace
3662
3663void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3664 MLIRContext *context) {
3665 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
3666 FoldUnusedShuffleOperand>(context);
3667}
3668
3669//===----------------------------------------------------------------------===//
3670// InsertOp
3671//===----------------------------------------------------------------------===//
3672
3673void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3674 SetIntRangeFn setResultRanges) {
3675 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3676}
3677
3678void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3679 Value source, Value dest) {
3680 auto vectorTy = cast<VectorType>(dest.getType());
3681 build(builder, result, source, dest,
3682 SmallVector<int64_t>(vectorTy.getRank(), 0));
3683}
3684
3685void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3686 Value source, Value dest, int64_t position) {
3687 build(builder, result, source, dest, ArrayRef<int64_t>{position});
3688}
3689
3690void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3691 Value source, Value dest, OpFoldResult position) {
3692 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3693}
3694
3695void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3696 Value source, Value dest,
3697 ArrayRef<int64_t> position) {
3698 SmallVector<OpFoldResult> posVals;
3699 posVals.reserve(position.size());
3700 llvm::transform(position, std::back_inserter(posVals),
3701 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3702 build(builder, result, source, dest, posVals);
3703}
3704
3705void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3706 Value source, Value dest,
3707 ArrayRef<OpFoldResult> position) {
3708 SmallVector<int64_t> staticPos;
3709 SmallVector<Value> dynamicPos;
3710 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3711 build(builder, result, source, dest, dynamicPos,
3712 builder.getDenseI64ArrayAttr(staticPos));
3713}
3714
3715LogicalResult InsertOp::verify() {
3716 if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3717 if (srcTy.getRank() == 0)
3718 return emitError(
3719 "expected a scalar instead of a 0-d vector as the source operand");
3720
3721 SmallVector<OpFoldResult> position = getMixedPosition();
3722 auto destVectorType = getDestVectorType();
3723 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3724 return emitOpError(
3725 "expected position attribute of rank no greater than dest vector rank");
3726 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3727 if (srcVectorType &&
3728 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3729 static_cast<unsigned>(destVectorType.getRank())))
3730 return emitOpError("expected position attribute rank + source rank to "
3731 "match dest vector rank");
3732 if (!srcVectorType &&
3733 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3734 return emitOpError(
3735 "expected position attribute rank to match the dest vector rank");
3736 for (auto [idx, pos] : llvm::enumerate(position)) {
3737 if (auto attr = dyn_cast<Attribute>(pos)) {
3738 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3739 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3740 destVectorType.getDimSize(idx))) {
3741 return emitOpError("expected position attribute #")
3742 << (idx + 1)
3743 << " to be a non-negative integer smaller than the "
3744 "corresponding "
3745 "dest vector dimension";
3746 }
3747 }
3748 }
3749 return success();
3750}
3751
3752// Calculate the linearized position of the continuous chunk of elements to
3753// insert, based on the shape of the value to insert and the positions to insert
3754// at.
3755static int64_t calculateInsertPosition(VectorType destTy,
3756 ArrayRef<int64_t> positions) {
3757 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3758 assert(positions.size() <= completePositions.size() &&
3759 "positions size must be less than or equal to destTy rank");
3760 copy(positions, completePositions.begin());
3761 return linearize(completePositions, computeStrides(destTy.getShape()));
3762}
3763
3764namespace {
3765
3766// If insertOp is only inserting unit dimensions it can be transformed to a
3767// broadcast.
3768class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3769public:
3770 using Base::Base;
3771
3772 LogicalResult matchAndRewrite(InsertOp insertOp,
3773 PatternRewriter &rewriter) const override {
3774 auto srcVecType =
3775 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3776 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3777 srcVecType.getNumElements())
3778 return failure();
3779 rewriter.replaceOpWithNewOp<BroadcastOp>(
3780 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3781 return success();
3782 }
3783};
3784
3785/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
3786class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3787public:
3788 using Base::Base;
3789
3790 LogicalResult matchAndRewrite(InsertOp op,
3791 PatternRewriter &rewriter) const override {
3792
3793 Value splat = getScalarSplatSource(op.getValueToStore());
3794 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3795 return failure();
3796
3797 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3798 return success();
3799 }
3800};
3801
3802/// Pattern to optimize a chain of insertions.
3803///
3804/// This pattern identifies chains of vector.insert operations that:
3805/// 1. Only insert values at static positions.
3806/// 2. Completely initialize all elements in the resulting vector.
3807/// 3. All intermediate insert operations have only one use.
3808///
3809/// When these conditions are met, the entire chain can be replaced with a
3810/// single vector.from_elements operation.
3811///
3812/// To keep this pattern simple, and avoid spending too much time on matching
3813/// fragmented insert chains, this pattern only considers the last insert op in
3814/// the chain.
3815///
3816/// Example transformation:
3817/// %poison = ub.poison : vector<2xi32>
3818/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3819/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3820/// ->
3821/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3822class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3823public:
3824 using Base::Base;
3825 LogicalResult matchAndRewrite(InsertOp op,
3826 PatternRewriter &rewriter) const override {
3827
3828 VectorType destTy = op.getDestVectorType();
3829 if (destTy.isScalable())
3830 return failure();
3831 // Ensure this is the trailing vector.insert op in a chain of inserts.
3832 for (Operation *user : op.getResult().getUsers())
3833 if (auto insertOp = dyn_cast<InsertOp>(user))
3834 if (insertOp.getDest() == op.getResult())
3835 return failure();
3836
3837 InsertOp currentOp = op;
3838 SmallVector<InsertOp> chainInsertOps;
3839 while (currentOp) {
3840 // Check cond 1: Dynamic position is not supported.
3841 if (currentOp.hasDynamicPosition())
3842 return failure();
3843
3844 chainInsertOps.push_back(currentOp);
3845 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3846 // Check cond 3: Intermediate inserts have only one use to avoid an
3847 // explosion of vectors.
3848 if (currentOp && !currentOp->hasOneUse())
3849 return failure();
3850 }
3851
3852 int64_t vectorSize = destTy.getNumElements();
3853 int64_t initializedCount = 0;
3854 SmallVector<bool> initializedDestIdxs(vectorSize, false);
3855 SmallVector<int64_t> pendingInsertPos;
3856 SmallVector<int64_t> pendingInsertSize;
3857 SmallVector<Value> pendingInsertValues;
3858
3859 for (auto insertOp : chainInsertOps) {
3860 // This pattern can do nothing with poison index.
3861 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3862 return failure();
3863
3864 // Calculate the linearized position for inserting elements.
3865 int64_t insertBeginPosition =
3866 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3867
3868 // The valueToStore operand may be a vector or a scalar. Need to handle
3869 // both cases.
3870 int64_t insertSize = 1;
3871 if (auto srcVectorType =
3872 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3873 insertSize = srcVectorType.getNumElements();
3874
3875 assert(insertBeginPosition + insertSize <= vectorSize &&
3876 "insert would overflow the vector");
3877
3878 for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3879 insertBeginPosition + insertSize)) {
3880 if (initializedDestIdxs[index])
3881 continue;
3882 initializedDestIdxs[index] = true;
3883 ++initializedCount;
3884 }
3885
3886 // Defer the creation of ops before we can make sure the pattern can
3887 // succeed.
3888 pendingInsertPos.push_back(insertBeginPosition);
3889 pendingInsertSize.push_back(insertSize);
3890 pendingInsertValues.push_back(insertOp.getValueToStore());
3891
3892 if (initializedCount == vectorSize)
3893 break;
3894 }
3895
3896 // Check cond 2: all positions must be initialized.
3897 if (initializedCount != vectorSize)
3898 return failure();
3899
3900 SmallVector<Value> elements(vectorSize);
3901 for (auto [insertBeginPosition, insertSize, valueToStore] :
3902 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3903 pendingInsertValues))) {
3904 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3905
3906 if (!srcVectorType) {
3907 elements[insertBeginPosition] = valueToStore;
3908 continue;
3909 }
3910
3911 Repeated<Type> elementToInsertTypes(insertSize,
3912 srcVectorType.getElementType());
3913 // Get all elements from the vector in row-major order.
3914 auto elementsToInsert = vector::ToElementsOp::create(
3915 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3916 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3917 elements[insertBeginPosition + linearIdx] =
3918 elementsToInsert.getResult(linearIdx);
3919 }
3920 }
3921
3922 rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3923 return success();
3924 }
3925};
3926
3927} // namespace
3928
/// Folds a vector.insert with static indices into a constant destination:
/// the elements of `srcAttr` are written over the matching elements of the
/// dense destination constant `dstAttr`. Returns a null attribute when the
/// fold does not apply (dynamic or poison indices, missing or non-dense
/// destination, missing source, scalable destination, or a destination larger
/// than `maxVectorSizeFoldThreshold` elements on an op with more than one use).
3929static Attribute
3931 Attribute dstAttr,
3932 int64_t maxVectorSizeFoldThreshold) {
3933 if (insertOp.hasDynamicPosition())
3934 return {};
3935
3936 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3937 if (!denseDst)
3938 return {};
3939
3940 if (!srcAttr) {
3941 return {};
3942 }
3943
3944 VectorType destTy = insertOp.getDestVectorType();
3945 if (destTy.isScalable())
3946 return {};
3947
3948 // Make sure we do not create too many large constants.
3949 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3950 !insertOp->hasOneUse())
3951 return {};
3952
3953 // Bail out on poison indices (kPoisonIndex = -1) to avoid computing an
3954 // invalid (negative) linearized position which would cause UB below.
3955 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3956 return {};
3957
3958 // Calculate the linearized position for inserting elements.
3959 int64_t insertBeginPosition =
3960 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3961 SmallVector<Attribute> insertedValues;
3962 Type destEltType = destTy.getElementType();
3963
3964 /// Converts attribute to the expected type if there's
3965 /// a mismatch.
 // A dense source contributes all of its elements; any other source
 // attribute is treated as a single element.
3966 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3967 for (auto value : denseSource.getValues<Attribute>())
3968 insertedValues.push_back(convertNumericAttr(value, destEltType));
3969 } else {
3970 insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3971 }
3972
 // Overwrite the destination elements starting at the linearized position.
3973 auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3974 copy(insertedValues, allValues.begin() + insertBeginPosition);
3975 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3976
3977 return newAttr;
3978}
3979
3980/// Folder to replace the `dest` operand of the insert op with the root dest of
3981/// the insert op use chain.
3982static Value foldInsertUseChain(InsertOp insertOp) {
3983 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3984 if (!destInsert)
3985 return {};
3986
3987 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3988 return {};
3989
3990 insertOp.setOperand(1, destInsert.getDest());
3991 return insertOp.getResult();
3992}
3993
3994void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3995 MLIRContext *context) {
3996 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3997 InsertChainFullyInitialized>(context);
3998}
3999
4000OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
4001 // Do not create constants with more than `vectorSizeFoldThreashold` elements,
4002 // unless the source vector constant has a single use.
4003 constexpr int64_t vectorSizeFoldThreshold = 256;
4004 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
4005 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
4006 // (type mismatch).
4007 if (getNumIndices() == 0 && getValueToStoreType() == getType())
4008 return getValueToStore();
4009 // Fold `arith.constant` indices into the `vector.insert` operation.
4010 // Do not stop here as this fold may enable subsequent folds that require
4011 // constant indices.
4012 SmallVector<Value> operands = {getValueToStore(), getDest()};
4013 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
4014
4015 if (auto res = foldInsertUseChain(*this))
4016 return res;
4017 if (auto res = foldPoisonIndexInsertExtractOp(
4018 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
4019 return res;
4020 if (auto res = foldDenseElementsAttrDestInsertOp(
4021 *this, adaptor.getValueToStore(), adaptor.getDest(),
4022 vectorSizeFoldThreshold)) {
4023 return res;
4024 }
4025
4026 return inplaceFolded;
4027}
4028
4029//===----------------------------------------------------------------------===//
4030// InsertStridedSliceOp
4031//===----------------------------------------------------------------------===//
4032
4033void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4034 Value source, Value dest,
4035 ArrayRef<int64_t> offsets,
4036 ArrayRef<int64_t> strides) {
4037 result.addOperands({source, dest});
4038 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4039 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4040 result.addTypes(dest.getType());
4041 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
4042 offsetsAttr);
4043 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
4044 stridesAttr);
4045}
4046
// Succeeds if `arrayAttr` has at most as many entries as `shape`; otherwise
// emits an op error on `op` that names `attrName`.
4047// TODO: Should be moved to Tablegen ConfinedAttr attributes.
4048template <typename OpType>
4049static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
4050 ArrayAttr arrayAttr,
4052 StringRef attrName) {
4053 if (arrayAttr.size() > shape.size())
4054 return op.emitOpError("expected ")
4055 << attrName << " attribute of rank no greater than vector rank";
4056 return success();
4057}
4058
4059// Succeeds if all integers in `arrayAttr` are in the admissible interval;
4060// otherwise emits an op error on `op`. If `halfOpen` is true then the
4061// admissible interval is [min, max). Otherwise, it is [min, max].
4062template <typename OpType>
4063static LogicalResult
4065 int64_t max, StringRef attrName,
4066 bool halfOpen = true) {
4067 for (auto attr : arrayAttr) {
4068 auto val = llvm::cast<IntegerAttr>(attr).getInt();
 // Normalize to an exclusive upper bound so that a single comparison (and a
 // single "[min, upper)" diagnostic) covers both interval kinds.
4069 auto upper = max;
4070 if (!halfOpen)
4071 upper += 1;
4072 if (val < min || val >= upper)
4073 return op.emitOpError("expected ") << attrName << " to be confined to ["
4074 << min << ", " << upper << ")";
4075 }
4076 return success();
4077}
4078
4079// Succeeds if every integer in `arrayAttr` is in the admissible interval
4080// bounded by the entry of `shape` at the same index (called max below);
4081// otherwise emits an op error on `op` that reports the offending dimension.
// If `halfOpen` is true then the admissible interval is [min, max).
// Otherwise, the admissible interval is [min, max].
4082template <typename OpType>
4083static LogicalResult
4085 ArrayRef<int64_t> shape, StringRef attrName,
4086 bool halfOpen = true, int64_t min = 0) {
4087 for (auto [index, attrDimPair] :
4088 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
4089 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
4090 int64_t max = std::get<1>(attrDimPair);
 // Normalize to an exclusive upper bound.
4091 if (!halfOpen)
4092 max += 1;
4093 if (val < min || val >= max)
4094 return op.emitOpError("expected ")
4095 << attrName << " dimension " << index << " to be confined to ["
4096 << min << ", " << max << ")";
4097 }
4098 return success();
4099}
4100
4101// Succeeds if, for every index i covered by `arrayAttr1`, `arrayAttr2` and
4102// `shape`, val is in the admissible interval, where
4103// val = `arrayAttr1[i]` + `arrayAttr2[i]` and max = `shape[i]`.
4104// If `halfOpen` is true then the upper bound is exclusive (val < max).
4105// Otherwise it is inclusive (val <= max). Emits an op error on failure.
4106template <typename OpType>
4108 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
4109 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
4110 bool halfOpen = true, int64_t min = 1) {
4111 assert(arrayAttr1.size() <= shape.size());
4112 assert(arrayAttr2.size() <= shape.size());
4113 for (auto [index, it] :
4114 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
4115 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
4116 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
4117 int64_t max = std::get<2>(it);
 // Normalize to an exclusive upper bound.
4118 if (!halfOpen)
4119 max += 1;
 // NOTE(review): the lower bound tested here is the literal 0, while the
 // diagnostic below prints `min`; `min` is not used in the comparison.
 // Confirm whether `val1 + val2 < min` was intended.
4120 if (val1 + val2 < 0 || val1 + val2 >= max)
4121 return op.emitOpError("expected sum(")
4122 << attrName1 << ", " << attrName2 << ") dimension " << index
4123 << " to be confined to [" << min << ", " << max << ")";
4124 }
4125 return success();
4126}
4127
// Wraps each entry of `values` into a 64-bit IntegerAttr and returns them as an
// ArrayAttr created in `context`.
4129 MLIRContext *context) {
4130 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
4131 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
4132 });
4133 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
4134}
4135
// Verifies a vector.insert_strided_slice:
//  * there is one offset per destination dimension and one stride per source
//    dimension, and the source rank does not exceed the destination rank;
//  * offsets lie within the destination shape and all strides are 1;
//  * for each source dimension, the scalable flag matches the corresponding
//    (trailing) destination dimension, and scalable dimensions have equal
//    sizes in source and destination.
4136LogicalResult InsertStridedSliceOp::verify() {
4137 auto sourceVectorType = getSourceVectorType();
4138 auto destVectorType = getDestVectorType();
4139 auto offsets = getOffsetsAttr();
4140 auto strides = getStridesAttr();
4141 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
4142 return emitOpError(
4143 "expected offsets of same size as destination vector rank");
4144 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
4145 return emitOpError("expected strides of same size as source vector rank");
4146 if (sourceVectorType.getRank() > destVectorType.getRank())
4147 return emitOpError(
4148 "expected source rank to be no greater than destination rank");
4149
4150 auto sourceShape = sourceVectorType.getShape();
4151 auto destShape = destVectorType.getShape();
 // Left-pad the source shape with zeros up to the destination rank so that it
 // can be paired with `offsets` dimension by dimension.
4152 SmallVector<int64_t, 4> sourceShapeAsDestShape(
4153 destShape.size() - sourceShape.size(), 0);
4154 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
4155 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
4156 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
4157 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
4158 offName)) ||
4159 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4160 /*max=*/1, stridesName,
4161 /*halfOpen=*/false)) ||
4163 *this, offsets,
4164 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
4165 offName, "source vector shape",
4166 /*halfOpen=*/false, /*min=*/1)))
4167 return failure();
4168
 // Source dimension `idx` corresponds to destination dimension
 // `idx + rankDiff`.
4169 unsigned rankDiff = destShape.size() - sourceShape.size();
4170 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
4171 if (sourceVectorType.getScalableDims()[idx] !=
4172 destVectorType.getScalableDims()[idx + rankDiff]) {
4173 return emitOpError("mismatching scalable flags (at source vector idx=")
4174 << idx << ")";
4175 }
 // Scalable dimensions must have the same size in source and destination.
4176 if (sourceVectorType.getScalableDims()[idx]) {
4177 auto sourceSize = sourceShape[idx];
4178 auto destSize = destShape[idx + rankDiff];
4179 if (sourceSize != destSize) {
4180 return emitOpError("expected size at idx=")
4181 << idx
4182 << (" to match the corresponding base size from the input "
4183 "vector (")
4184 << sourceSize << (" vs ") << destSize << (")");
4185 }
4186 }
4187 }
4188
4189 return success();
4190}
4191
4192namespace {
4193/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
4194class FoldInsertStridedSliceSplat final
4195 : public OpRewritePattern<InsertStridedSliceOp> {
4196public:
4197 using Base::Base;
4198
4199 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4200 PatternRewriter &rewriter) const override {
4201
4202 auto dst = insertStridedSliceOp.getDest();
4203 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4204 if (!splat || getScalarSplatSource(dst) != splat)
4205 return failure();
4206
4207 rewriter.replaceOp(insertStridedSliceOp, dst);
4208 return success();
4209 }
4210};
4211
4212/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
4213/// to dst.
4214class FoldInsertStridedSliceOfExtract final
4215 : public OpRewritePattern<InsertStridedSliceOp> {
4216public:
4217 using Base::Base;
4218
4219 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4220 PatternRewriter &rewriter) const override {
4221 auto extractStridedSliceOp =
4222 insertStridedSliceOp.getValueToStore()
4223 .getDefiningOp<vector::ExtractStridedSliceOp>();
4224
4225 if (!extractStridedSliceOp)
4226 return failure();
4227
4228 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4229 return failure();
4230
4231 // Check if have the same strides and offsets.
4232 if (extractStridedSliceOp.getStrides() !=
4233 insertStridedSliceOp.getStrides() ||
4234 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4235 return failure();
4236
4237 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4238 return success();
4239 }
4240};
4241
4242// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
4243// ConstantOp.
4244class InsertStridedSliceConstantFolder final
4245 : public OpRewritePattern<InsertStridedSliceOp> {
4246public:
4247 using Base::Base;
4248
4249 // Do not create constants with more than `vectorSizeFoldThreashold` elements,
4250 // unless the source vector constant has a single use.
4251 static constexpr int64_t vectorSizeFoldThreshold = 256;
4252
4253 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4254 PatternRewriter &rewriter) const override {
4255 // Return if 'InsertOp' operand is not defined by a compatible vector
4256 // ConstantOp.
4257 TypedValue<VectorType> destVector = op.getDest();
4258 Attribute vectorDestCst;
4259 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
4260 return failure();
4261
4262 VectorType destTy = destVector.getType();
4263 if (destTy.isScalable())
4264 return failure();
4265
4266 // Make sure we do not create too many large constants.
4267 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4268 !destVector.hasOneUse())
4269 return failure();
4270
4271 TypedValue<VectorType> sourceValue = op.getValueToStore();
4272 Attribute sourceCst;
4273 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
4274 return failure();
4275
4276 // TODO: Support poison.
4277 if (matchPattern(vectorDestCst, ub::m_Poison()) ||
4278 matchPattern(sourceCst, ub::m_Poison()))
4279 return failure();
4280
4281 // TODO: Handle non-unit strides when they become available.
4282 if (op.hasNonUnitStrides())
4283 return failure();
4284
4285 VectorType sliceVecTy = sourceValue.getType();
4286 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4287 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4288 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
4289 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
4290
4291 // Calcualte the destination element indices by enumerating all slice
4292 // positions within the destination and linearizing them. The enumeration
4293 // order is lexicographic which yields a sequence of monotonically
4294 // increasing linearized position indices.
4295 // Because the destination may have higher dimensionality then the slice,
4296 // we keep track of two overlapping sets of positions and offsets.
4297 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4298 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4299 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4300 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4301 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4302 MutableArrayRef<int64_t> currSlicePosition(
4303 currDestPosition.begin() + rankDifference, currDestPosition.end());
4304 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4305 offsets.end());
4306 do {
4307 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
4308 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
4309 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4310 "Invalid slice element");
4311 newValues[linearizedPosition] = *sliceValuesIt;
4312 ++sliceValuesIt;
4313 } while (succeeded(
4314 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
4315
4316 auto newAttr = DenseElementsAttr::get(destTy, newValues);
4317 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
4318 return success();
4319 }
4320};
4321
4322} // namespace
4323
4324void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4325 RewritePatternSet &results, MLIRContext *context) {
4326 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4327 InsertStridedSliceConstantFolder>(context);
4328}
4329
4330OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4331 if (getSourceVectorType() == getDestVectorType())
4332 return getValueToStore();
4333 return {};
4334}
4335
4336//===----------------------------------------------------------------------===//
4337// OuterProductOp
4338//===----------------------------------------------------------------------===//
4339
4340/// Build an op without mask, use the type of `acc` as the return type.
4341void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4342 Value lhs, Value rhs, Value acc) {
4343 result.addOperands({lhs, rhs, acc});
4344 result.addTypes(acc.getType());
4345}
4346
4347void OuterProductOp::print(OpAsmPrinter &p) {
4348 p << " " << getLhs() << ", " << getRhs();
4349 if (getAcc()) {
4350 p << ", " << getAcc();
4351 p.printOptionalAttrDict((*this)->getAttrs());
4352 }
4353 p << " : " << getLhs().getType() << ", " << getRhs().getType();
4354}
4355
4356ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4357 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4358 Type tLHS, tRHS;
4359 if (parser.parseOperandList(operandsInfo) ||
4360 parser.parseOptionalAttrDict(result.attributes) ||
4361 parser.parseColonType(tLHS) || parser.parseComma() ||
4362 parser.parseType(tRHS))
4363 return failure();
4364 if (operandsInfo.size() < 2)
4365 return parser.emitError(parser.getNameLoc(),
4366 "expected at least 2 operands");
4367 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4368 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4369 if (!vLHS)
4370 return parser.emitError(parser.getNameLoc(),
4371 "expected vector type for operand #1");
4372
4373 VectorType resType;
4374 if (vRHS) {
4375 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4376 vRHS.getScalableDims()[0]};
4377 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4378 vLHS.getElementType(), scalableDimsRes);
4379 } else {
4380 // Scalar RHS operand
4381 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4382 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4383 scalableDimsRes);
4384 }
4385
4386 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4387 result.attributes.append(
4388 OuterProductOp::getKindAttrName(result.name),
4389 CombiningKindAttr::get(result.getContext(),
4390 OuterProductOp::getDefaultKind()));
4391 }
4392
4393 return failure(
4394 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4395 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4396 (operandsInfo.size() > 2 &&
4397 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4398 parser.addTypeToList(resType, result.types));
4399}
4400
4401LogicalResult OuterProductOp::verify() {
4402 Type tRHS = getOperandTypeRHS();
4403 VectorType vLHS = getOperandVectorTypeLHS(),
4404 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4405 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4406
4407 if (vLHS.getRank() != 1)
4408 return emitOpError("expected 1-d vector for operand #1");
4409
4410 if (vRHS) {
4411 // Proper OUTER operation.
4412 if (vRHS.getRank() != 1)
4413 return emitOpError("expected 1-d vector for operand #2");
4414 if (vRES.getRank() != 2)
4415 return emitOpError("expected 2-d vector result");
4416 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4417 return emitOpError("expected #1 operand dim to match result dim #1");
4418 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4419 return emitOpError("expected #2 operand dim to match result dim #2");
4420 if (vLHS.isScalable() && !vRHS.isScalable()) {
4421 // This restriction reflects what's currently supported in terms of
4422 // scalable vectors. However, we could relax this if there's a use case.
4423 return emitOpError(
4424 "expected either both or only #2 operand dim to be scalable");
4425 }
4426 } else {
4427 // An AXPY operation.
4428 if (vRES.getRank() != 1)
4429 return emitOpError("expected 1-d vector result");
4430 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4431 return emitOpError("expected #1 operand dim to match result dim #1");
4432 }
4433
4434 if (vACC && vACC != vRES)
4435 return emitOpError("expected operand #3 of same type as result type");
4436
4437 if (!getKindAttr()) {
4438 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
4439 "'vector.kind<add>')");
4440 }
4441
4442 // Verify supported combining kind.
4443 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4444 return emitOpError("unsupported outerproduct type");
4445
4446 return success();
4447}
4448
4449// MaskableOpInterface methods.
4450
4451/// Returns the mask type expected by this operation. Mostly used for
4452/// verification purposes. It requires the operation to be vectorized."
4453Type OuterProductOp::getExpectedMaskType() {
4454 auto vecType = this->getResultVectorType();
4455 return VectorType::get(vecType.getShape(),
4456 IntegerType::get(vecType.getContext(), /*width=*/1),
4457 vecType.getScalableDims());
4458}
4459
4460//===----------------------------------------------------------------------===//
4461// ExtractStridedSliceOp
4462//===----------------------------------------------------------------------===//
4463
4464// Inference works as follows:
4465// 1. Add 'sizes' from prefix of dims in 'offsets'.
4466// 2. Add sizes from 'vectorType' for remaining dims.
4467// Scalable flags are inherited from 'vectorType'.
4468static Type inferStridedSliceOpResultType(VectorType vectorType,
4469 ArrayAttr offsets, ArrayAttr sizes,
4470 ArrayAttr strides) {
4471 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4473 shape.reserve(vectorType.getRank());
4474 unsigned idx = 0;
4475 for (unsigned e = offsets.size(); idx < e; ++idx)
4476 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4477 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4478 shape.push_back(vectorType.getShape()[idx]);
4479
4480 return VectorType::get(shape, vectorType.getElementType(),
4481 vectorType.getScalableDims());
4482}
4483
4484void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4485 Value source, ArrayRef<int64_t> offsets,
4486 ArrayRef<int64_t> sizes,
4487 ArrayRef<int64_t> strides) {
4488 result.addOperands(source);
4489 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4490 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4491 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4492 result.addTypes(
4493 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4494 offsetsAttr, sizesAttr, stridesAttr));
4495 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4496 offsetsAttr);
4497 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4498 sizesAttr);
4499 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4500 stridesAttr);
4501}
4502
4503LogicalResult ExtractStridedSliceOp::verify() {
4504 auto type = getSourceVectorType();
4505 auto offsets = getOffsetsAttr();
4506 auto sizes = getSizesAttr();
4507 auto strides = getStridesAttr();
4508 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4509 return emitOpError(
4510 "expected offsets, sizes and strides attributes of same size");
4511
4512 auto shape = type.getShape();
4513 auto offName = getOffsetsAttrName();
4514 auto sizesName = getSizesAttrName();
4515 auto stridesName = getStridesAttrName();
4516 if (failed(
4517 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4518 failed(
4519 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4520 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4521 stridesName)) ||
4522 failed(
4523 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4524 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4525 /*halfOpen=*/false,
4526 /*min=*/1)) ||
4527 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4528 /*max=*/1, stridesName,
4529 /*halfOpen=*/false)) ||
4530 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4531 shape, offName, sizesName,
4532 /*halfOpen=*/false)))
4533 return failure();
4534
4535 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4536 offsets, sizes, strides);
4537 if (getResult().getType() != resultType)
4538 return emitOpError("expected result type to be ") << resultType;
4539
4540 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4541 if (type.getScalableDims()[idx]) {
4542 auto inputDim = type.getShape()[idx];
4543 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4544 if (inputDim != inputSize)
4545 return emitOpError("expected size at idx=")
4546 << idx
4547 << (" to match the corresponding base size from the input "
4548 "vector (")
4549 << inputSize << (" vs ") << inputDim << (")");
4550 }
4551 }
4552
4553 return success();
4554}
4555
4556// When the source of ExtractStrided comes from a chain of InsertStrided ops try
4557// to use the source of the InsertStrided ops if we can detect that the
4558// extracted vector is a subset of one of the vector inserted.
4559static LogicalResult
4560foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4561 // Helper to extract integer out of ArrayAttr.
4562 auto getElement = [](ArrayAttr array, int idx) {
4563 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4564 };
4565 ArrayAttr extractOffsets = op.getOffsets();
4566 ArrayAttr extractStrides = op.getStrides();
4567 ArrayAttr extractSizes = op.getSizes();
4568 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4569 while (insertOp) {
4570 if (op.getSourceVectorType().getRank() !=
4571 insertOp.getSourceVectorType().getRank())
4572 return failure();
4573 ArrayAttr insertOffsets = insertOp.getOffsets();
4574 ArrayAttr insertStrides = insertOp.getStrides();
4575 // If the rank of extract is greater than the rank of insert, we are likely
4576 // extracting a partial chunk of the vector inserted.
4577 if (extractOffsets.size() > insertOffsets.size())
4578 return failure();
4579 bool patialoverlap = false;
4580 bool disjoint = false;
4581 SmallVector<int64_t, 4> offsetDiffs;
4582 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4583 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4584 return failure();
4585 int64_t start = getElement(insertOffsets, dim);
4586 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4587 int64_t offset = getElement(extractOffsets, dim);
4588 int64_t size = getElement(extractSizes, dim);
4589 // Check if the start of the extract offset is in the interval inserted.
4590 if (start <= offset && offset < end) {
4591 // If the extract interval overlaps but is not fully included we may
4592 // have a partial overlap that will prevent any folding.
4593 if (offset + size > end)
4594 patialoverlap = true;
4595 offsetDiffs.push_back(offset - start);
4596 continue;
4597 }
4598 disjoint = true;
4599 break;
4600 }
4601 // The extract element chunk is a subset of the insert element.
4602 if (!disjoint && !patialoverlap) {
4603 op.setOperand(insertOp.getValueToStore());
4604 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4605 OpBuilder b(op.getContext());
4606 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4607 return success();
4608 }
4609 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4610 // in the insert chain.
4611 if (disjoint)
4612 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4613 else {
4614 // The extracted vector partially overlap the inserted vector, we cannot
4615 // fold.
4616 return failure();
4617 }
4618 }
4619 return failure();
4620}
4621
4622// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4623static OpFoldResult
4625 Attribute foldInput) {
4626
4627 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4628 if (!dense)
4629 return {};
4630
4631 // TODO: Handle non-unit strides when they become available.
4632 if (op.hasNonUnitStrides())
4633 return {};
4634
4635 VectorType sourceVecTy = op.getSourceVectorType();
4636 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4637 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4638
4639 VectorType sliceVecTy = op.getType();
4640 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4641 int64_t rank = sliceVecTy.getRank();
4642
4643 // Expand offsets and sizes to match the vector rank.
4644 SmallVector<int64_t, 4> offsets(rank, 0);
4645 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4646
4647 SmallVector<int64_t, 4> sizes(sourceShape);
4648 copy(getI64SubArray(op.getSizes()), sizes.begin());
4649
4650 // Calculate the slice elements by enumerating all slice positions and
4651 // linearizing them. The enumeration order is lexicographic which yields a
4652 // sequence of monotonically increasing linearized position indices.
4653 const auto denseValuesBegin = dense.value_begin<Attribute>();
4654 SmallVector<Attribute> sliceValues;
4655 sliceValues.reserve(sliceVecTy.getNumElements());
4656 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4657 do {
4658 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4659 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4660 "Invalid index");
4661 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4662 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4663
4664 assert(static_cast<int64_t>(sliceValues.size()) ==
4665 sliceVecTy.getNumElements() &&
4666 "Invalid number of slice elements");
4667 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4668}
4669
4670OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4671 if (getSourceVectorType() == getResult().getType())
4672 return getSource();
4673 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4674 return getResult();
4675
4676 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4677 if (auto splat =
4678 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4679 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4680
4681 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4682 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4683}
4684
4685void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4686 populateFromInt64AttrArray(getOffsets(), results);
4687}
4688
4689namespace {
4690
4691// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
4692//
4693// Example:
4694//
4695// %0 = vector.extract_strided_slice %arg0
4696// {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
4697// : vector<4x8x16xf32> to vector<3x4x16xf32>
4698// %1 = vector.extract_strided_slice %0
4699// {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
4700// : vector<3x4x16xf32> to vector<2x2x16xf32>
4701//
4702// to
4703//
4704// %1 = vector.extract_strided_slice %arg0
4705// {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
4706// : vector<4x8x16xf32> to vector<2x2x16xf32>
4707class StridedSliceFolder final
4708 : public OpRewritePattern<ExtractStridedSliceOp> {
4709public:
4710 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4711
4712 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4713 PatternRewriter &rewriter) const override {
4714 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4715 if (!firstOp)
4716 return failure();
4717
4718 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4719 return failure();
4720
4721 SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
4722 SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
4723 SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
4724 SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());
4725
4726 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4727 SmallVector<int64_t> combinedOffsets(newRank, 0);
4728 SmallVector<int64_t> combinedSizes(newRank);
4729 ArrayRef<int64_t> firstSourceShape =
4730 firstOp.getSourceVectorType().getShape();
4731 for (unsigned i = 0; i < newRank; ++i) {
4732 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4733 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4734 combinedOffsets[i] = off1 + off2;
4735
4736 if (i < secondSizes.size()) {
4737 combinedSizes[i] = secondSizes[i];
4738 } else if (i < firstSizes.size()) {
4739 combinedSizes[i] = firstSizes[i];
4740 } else {
4741 combinedSizes[i] = firstSourceShape[i];
4742 }
4743 }
4744
4745 SmallVector<int64_t> combinedStrides(newRank, 1);
4746 rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
4747 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4748 combinedStrides);
4749 return success();
4750 }
4751};
4752
4753// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4754// CreateMaskOp.
4755//
4756// Example:
4757//
4758// %mask = vector.create_mask %ub : vector<16xi1>
4759// %slice = vector.extract_strided_slice [%offset] [8] [1]
4760//
4761// to
4762//
4763// %new_ub = arith.subi %ub, %offset
4764// %mask = vector.create_mask %new_ub : vector<8xi1>
4765class StridedSliceCreateMaskFolder final
4766 : public OpRewritePattern<ExtractStridedSliceOp> {
4767 using Base::Base;
4768
4769public:
4770 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4771 PatternRewriter &rewriter) const override {
4772 Location loc = extractStridedSliceOp.getLoc();
4773 // Return if 'extractStridedSliceOp' operand is not defined by a
4774 // CreateMaskOp.
4775 auto createMaskOp =
4776 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4777 if (!createMaskOp)
4778 return failure();
4779 // Return if 'extractStridedSliceOp' has non-unit strides.
4780 if (extractStridedSliceOp.hasNonUnitStrides())
4781 return failure();
4782 // Gather constant mask dimension sizes.
4783 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4784 // Gather strided slice offsets and sizes.
4785 SmallVector<int64_t> sliceOffsets;
4786 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4787 sliceOffsets);
4788 SmallVector<int64_t> sliceSizes;
4789 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4790
4791 // Compute slice of vector mask region.
4792 SmallVector<Value> sliceMaskDimSizes;
4793 sliceMaskDimSizes.reserve(maskDimSizes.size());
4794 // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4795 // only iterate on the leading dim sizes. The tail accounts for the
4796 // remaining dim sizes.
4797 for (auto [maskDimSize, sliceOffset, sliceSize] :
4798 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4799 // No need to clamp on min/max values, because create_mask has clamping
4800 // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4801 // greater than the vector dim size.
4802 IntegerAttr offsetAttr =
4803 rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4804 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4805 Value sliceMaskDimSize =
4806 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4807 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4808 }
4809 // Add unchanged dimensions.
4810 llvm::append_range(
4811 sliceMaskDimSizes,
4812 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4813 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4814 // region.
4815 rewriter.replaceOpWithNewOp<CreateMaskOp>(
4816 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4817 sliceMaskDimSizes);
4818 return success();
4819 }
4820};
4821
4822// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4823// ConstantMaskOp.
4824class StridedSliceConstantMaskFolder final
4825 : public OpRewritePattern<ExtractStridedSliceOp> {
4826public:
4827 using Base::Base;
4828
4829 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4830 PatternRewriter &rewriter) const override {
4831 // Return if 'extractStridedSliceOp' operand is not defined by a
4832 // ConstantMaskOp.
4833 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4834 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4835 if (!constantMaskOp)
4836 return failure();
4837 // Return if 'extractStridedSliceOp' has non-unit strides.
4838 if (extractStridedSliceOp.hasNonUnitStrides())
4839 return failure();
4840 // Gather constant mask dimension sizes.
4841 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4842 // Gather strided slice offsets and sizes.
4843 SmallVector<int64_t> sliceOffsets;
4844 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4845 sliceOffsets);
4846 SmallVector<int64_t> sliceSizes;
4847 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4848
4849 // Compute slice of vector mask region.
4850 SmallVector<int64_t> sliceMaskDimSizes;
4851 sliceMaskDimSizes.reserve(maskDimSizes.size());
4852 for (auto [maskDimSize, sliceOffset, sliceSize] :
4853 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4854 int64_t sliceMaskDimSize = std::max(
4855 static_cast<int64_t>(0),
4856 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4857 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4858 }
4859 // Add unchanged dimensions.
4860 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4861 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4862 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4863 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4864 // region is a conjunction of mask dim intervals).
4865 if (llvm::is_contained(sliceMaskDimSizes, 0))
4866 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4867
4868 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4869 // region.
4870 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4871 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4872 sliceMaskDimSizes);
4873 return success();
4874 }
4875};
4876
4877// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4878// BroadcastOp(ExtractStrideSliceOp).
4879class StridedSliceBroadcast final
4880 : public OpRewritePattern<ExtractStridedSliceOp> {
4881public:
4882 using Base::Base;
4883
4884 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4885 PatternRewriter &rewriter) const override {
4886 auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4887 if (!broadcast)
4888 return failure();
4889 auto srcVecType =
4890 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4891 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4892 auto dstVecType = llvm::cast<VectorType>(op.getType());
4893 unsigned dstRank = dstVecType.getRank();
4894 unsigned rankDiff = dstRank - srcRank;
4895 // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4896 // (n -> m with n > m). If they are originally both broadcasted *and*
4897 // sliced, this can be simplified to just broadcasting.
4898 bool needsSlice = false;
4899 for (unsigned i = 0; i < srcRank; i++) {
4900 if (srcVecType.getDimSize(i) != 1 &&
4901 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4902 needsSlice = true;
4903 break;
4904 }
4905 }
4906 Value source = broadcast.getSource();
4907 if (needsSlice) {
4908 SmallVector<int64_t> offsets =
4909 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4910 SmallVector<int64_t> sizes =
4911 getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4912 for (unsigned i = 0; i < srcRank; i++) {
4913 if (srcVecType.getDimSize(i) == 1) {
4914 // In case this dimension was broadcasted *and* sliced, the offset
4915 // and size need to be updated now that there is no broadcast before
4916 // the slice.
4917 offsets[i] = 0;
4918 sizes[i] = 1;
4919 }
4920 }
4921 source = ExtractStridedSliceOp::create(
4922 rewriter, op->getLoc(), source, offsets, sizes,
4923 getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4924 }
4925 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4926 return success();
4927 }
4928};
4929
4930/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
4931class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4932public:
4933 using Base::Base;
4934
4935 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4936 PatternRewriter &rewriter) const override {
4937
4938 Value splat = getScalarSplatSource(op.getSource());
4939 if (!splat)
4940 return failure();
4941 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4942 return success();
4943 }
4944};
4945
4946/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4947/// slice is contiguous, into extract and shape_cast.
4948///
4949/// Example:
4950/// Before:
4951/// %1 = vector.extract_strided_slice %arg0 {
4952/// offsets = [0, 0, 0, 0, 0],
4953/// sizes = [1, 1, 1, 1, 8],
4954/// strides = [1, 1, 1, 1, 1]
4955/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4956/// After:
4957/// %0 = vector.extract %arg0[0, 0, 0, 0]
4958/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4959/// %1 = vector.shape_cast %0
4960/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4961///
4962class ContiguousExtractStridedSliceToExtract final
4963 : public OpRewritePattern<ExtractStridedSliceOp> {
4964public:
4965 using Base::Base;
4966
4967 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4968 PatternRewriter &rewriter) const override {
4969 if (op.hasNonUnitStrides())
4970 return failure();
4971 Value source = op.getOperand();
4972 auto sourceType = cast<VectorType>(source.getType());
4973 if (sourceType.isScalable() || sourceType.getRank() == 0)
4974 return failure();
4975
4976 // Compute the number of offsets to pass to ExtractOp::build. That is the
4977 // difference between the source rank and the desired slice rank. We walk
4978 // the dimensions from innermost out, and stop when the next slice dimension
4979 // is not full-size.
4980 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4981 int numOffsets;
4982 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4983 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4984 break;
4985 }
4986
4987 // If the created extract op would have no offsets, then this whole
4988 // extract_strided_slice is the identity and should have been handled by
4989 // other canonicalizations.
4990 if (numOffsets == 0)
4991 return failure();
4992
4993 // If not even the inner-most dimension is full-size, this op can't be
4994 // rewritten as an ExtractOp.
4995 if (numOffsets == sourceType.getRank() &&
4996 static_cast<int>(sizes.size()) == sourceType.getRank())
4997 return failure();
4998
4999 // The outer dimensions must have unit size.
5000 for (int i = 0; i < numOffsets; ++i) {
5001 if (sizes[i] != 1)
5002 return failure();
5003 }
5004
5005 // Avoid generating slices that have leading unit dimensions. The shape_cast
5006 // op that we create below would take bad generic fallback patterns
5007 // (ShapeCastOpRewritePattern).
5008 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
5009 sizes[numOffsets] == 1) {
5010 ++numOffsets;
5011 }
5012
5013 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
5014 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
5015 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
5016 extractOffsets);
5017 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
5018 return success();
5019 }
5020};
5021
5022} // namespace
5023
5024void ExtractStridedSliceOp::getCanonicalizationPatterns(
5025 RewritePatternSet &results, MLIRContext *context) {
5026 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
5027 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
5028 results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
5029 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
5030 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
5031 context);
5032}
5033
5034//===----------------------------------------------------------------------===//
5035// TransferReadOp
5036//===----------------------------------------------------------------------===//
5037
5038/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
5039/// If `padding` is null, a poison value is used.
5040void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5041 VectorType vectorType, Value source,
5042 ValueRange indices, std::optional<Value> padding,
5043 AffineMapAttr permutationMapAttr,
5044 /*optional*/ ArrayAttr inBoundsAttr) {
5045
5046 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5047 if (!padding)
5048 padding = ub::PoisonOp::create(builder, result.location, elemType);
5049 // Delegate to the most general builder (see
5050 // `mlir/Dialect/Vector/IR/VectorOps.cpp.inc`)
5051 build(builder, result, vectorType, source, indices, permutationMapAttr,
5052 *padding, /*mask=*/Value(), inBoundsAttr);
5053}
5054
5055/// 2. Builder that sets padding to zero and an empty mask (variant without
5056/// attrs).
5057/// If `padding` is null, a poison value is used.
5058/// If `permutationMap` is null, a minor identity map is used.
5059/// If `inBounds` is null, an empty mask is used.
5060void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5061 VectorType vectorType, Value source,
5062 ValueRange indices, std::optional<Value> padding,
5063 AffineMap permutationMap,
5064 std::optional<ArrayRef<bool>> inBounds) {
5065 if (!permutationMap)
5066 permutationMap = getTransferMinorIdentityMap(
5067 llvm::cast<ShapedType>(source.getType()), vectorType);
5068 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5069 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5070 ? builder.getBoolArrayAttr(inBounds.value())
5071 : builder.getBoolArrayAttr(
5072 SmallVector<bool>(vectorType.getRank(), false));
5073 // Delegate to Builder 1
5074 build(builder, result, vectorType, source, indices, padding,
5075 permutationMapAttr, inBoundsAttr);
5076}
5077
5078/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
5079/// If `padding` is null, a poison value is used.
5080/// If `inBounds` is null, an empty mask is used.
5081void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5082 VectorType vectorType, Value source,
5083 ValueRange indices, std::optional<Value> padding,
5084 std::optional<ArrayRef<bool>> inBounds) {
5085 // Delegate to Builder 2
5086 build(builder, result, vectorType, source, indices, padding,
5087 /*permutationMap=*/AffineMap(), inBounds);
5088}
5089
5090template <typename EmitFun>
5091static LogicalResult verifyPermutationMap(AffineMap permutationMap,
5092 EmitFun emitOpError) {
5093 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
5094 for (auto expr : permutationMap.getResults()) {
5095 auto dim = dyn_cast<AffineDimExpr>(expr);
5096 auto zero = dyn_cast<AffineConstantExpr>(expr);
5097 if (zero) {
5098 if (zero.getValue() != 0) {
5099 return emitOpError(
5100 "requires a projected permutation_map (at most one dim or the zero "
5101 "constant can appear in each result)");
5102 }
5103 continue;
5104 }
5105 if (!dim) {
5106 return emitOpError("requires a projected permutation_map (at most one "
5107 "dim or the zero constant can appear in each result)");
5108 }
5109 if (seen[dim.getPosition()]) {
5110 return emitOpError(
5111 "requires a permutation_map that is a permutation (found one dim "
5112 "used more than once)");
5113 }
5114 seen[dim.getPosition()] = true;
5115 }
5116 return success();
5117}
5118
5119static LogicalResult
5120verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
5121 VectorType vectorType, VectorType maskType,
5122 VectorType inferredMaskType, AffineMap permutationMap,
5123 ArrayAttr inBounds) {
5124 if (op->hasAttr("masked")) {
5125 return op->emitOpError("masked attribute has been removed. "
5126 "Use in_bounds instead.");
5127 }
5128
5129 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
5130 return op->emitOpError(
5131 "requires source to be a memref or ranked tensor type");
5132
5133 auto elementType = shapedType.getElementType();
5134 DataLayout dataLayout = DataLayout::closest(op);
5135 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
5136 // Memref or tensor has vector element type.
5137 unsigned sourceVecSize =
5138 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
5139 vectorElementType.getShape().back();
5140 unsigned resultVecSize =
5141 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
5142 vectorType.getShape().back();
5143 if (resultVecSize % sourceVecSize != 0)
5144 return op->emitOpError(
5145 "requires the bitwidth of the minor 1-D vector to be an integral "
5146 "multiple of the bitwidth of the minor 1-D vector of the source");
5147
5148 unsigned sourceVecEltRank = vectorElementType.getRank();
5149 unsigned resultVecRank = vectorType.getRank();
5150 if (sourceVecEltRank > resultVecRank)
5151 return op->emitOpError(
5152 "requires source vector element and vector result ranks to match.");
5153 unsigned rankOffset = resultVecRank - sourceVecEltRank;
5154 // Check that permutation map results match 'rankOffset' of vector type.
5155 if (permutationMap.getNumResults() != rankOffset)
5156 return op->emitOpError("requires a permutation_map with result dims of "
5157 "the same rank as the vector type");
5158
5159 if (maskType)
5160 return op->emitOpError("does not support masks with vector element type");
5161 } else {
5162 // Memref or tensor has scalar element type.
5163 unsigned minorSize =
5164 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
5165 unsigned resultVecSize =
5166 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
5167 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
5168 return op->emitOpError(
5169 "requires the bitwidth of the minor 1-D vector to be an integral "
5170 "multiple of the bitwidth of the source element type");
5171
5172 // Check that permutation map results match rank of vector type.
5173 if (permutationMap.getNumResults() != vectorType.getRank())
5174 return op->emitOpError("requires a permutation_map with result dims of "
5175 "the same rank as the vector type");
5176 }
5177
5178 if (permutationMap.getNumSymbols() != 0)
5179 return op->emitOpError("requires permutation_map without symbols");
5180
5181 if (permutationMap.getNumInputs() != shapedType.getRank())
5182 return op->emitOpError("requires a permutation_map with input dims of the "
5183 "same rank as the source type");
5184
5185 if (maskType && maskType != inferredMaskType)
5186 return op->emitOpError("inferred mask type (")
5187 << inferredMaskType << ") and mask operand type (" << maskType
5188 << ") don't match";
5189
5190 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
5191 return op->emitOpError("expects the in_bounds attr of same rank "
5192 "as permutation_map results: ")
5193 << AffineMapAttr::get(permutationMap)
5194 << " vs inBounds of size: " << inBounds.size();
5195
5196 return success();
5197}
5198
5199static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
5200 SmallVector<StringRef, 3> elidedAttrs;
5201 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5202 if (op.getPermutationMap().isMinorIdentity())
5203 elidedAttrs.push_back(op.getPermutationMapAttrName());
5204 // Elide in_bounds attribute if all dims are out-of-bounds.
5205 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
5206 elidedAttrs.push_back(op.getInBoundsAttrName());
5207 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5208}
5209
5210void TransferReadOp::print(OpAsmPrinter &p) {
5211 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
5212 if (getMask())
5213 p << ", " << getMask();
5214 printTransferAttrs(p, *this);
5215 p << " : " << getShapedType() << ", " << getVectorType();
5216}
5217
5218VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
5219 AffineMap permMap) {
5220 auto i1Type = IntegerType::get(permMap.getContext(), 1);
5221 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
5222 assert(invPermMap && "Inversed permutation map couldn't be computed");
5223 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
5224
5225 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
5226 // 0-D mask into a single-element 1-D mask.
5227 if (maskShape.empty())
5228 maskShape.push_back(1);
5229
5230 SmallVector<bool> scalableDims =
5231 applyPermutationMap(invPermMap, vecType.getScalableDims());
5232
5233 return VectorType::get(maskShape, i1Type, scalableDims);
5234}
5235
5236ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
5237 auto &builder = parser.getBuilder();
5238 SMLoc typesLoc;
5244 // Parsing with support for paddingValue.
5245 if (parser.parseOperand(sourceInfo) ||
5247 parser.parseComma() || parser.parseOperand(paddingInfo))
5248 return failure();
5249 ParseResult hasMask = parser.parseOptionalComma();
5250 if (hasMask.succeeded()) {
5251 if (parser.parseOperand(maskInfo))
5252 return failure();
5253 }
5254 if (parser.parseOptionalAttrDict(result.attributes) ||
5255 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5256 return failure();
5257 if (types.size() != 2)
5258 return parser.emitError(typesLoc, "requires two types");
5259 auto indexType = builder.getIndexType();
5260 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5261 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5262 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5263 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5264 if (!vectorType)
5265 return parser.emitError(typesLoc, "requires vector type");
5266 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
5267 Attribute permMapAttr = result.attributes.get(permMapAttrName);
5268 AffineMap permMap;
5269 if (!permMapAttr) {
5270 if (shapedType.getRank() <
5271 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5272 return parser.emitError(typesLoc,
5273 "expected a custom permutation_map when "
5274 "rank(source) != rank(destination)");
5275 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5276 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5277 } else {
5278 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5279 }
5280 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
5281 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5282 if (!inBoundsAttr) {
5283 result.addAttribute(inBoundsAttrName,
5284 builder.getBoolArrayAttr(
5285 SmallVector<bool>(permMap.getNumResults(), false)));
5286 }
5287 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5288 parser.resolveOperands(indexInfo, indexType, result.operands) ||
5289 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
5290 result.operands))
5291 return failure();
5292 if (hasMask.succeeded()) {
5293 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5294 return parser.emitError(
5295 maskInfo.location, "does not support masks with vector element type");
5296 if (vectorType.getRank() != permMap.getNumResults()) {
5297 return parser.emitError(typesLoc,
5298 "expected the same rank for the vector and the "
5299 "results of the permutation map");
5300 }
5301 // Instead of adding the mask type as an op type, compute it based on the
5302 // vector type and the permutation map (to keep the type signature small).
5303 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5304 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5305 return failure();
5306 }
5307 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5308 builder.getDenseI32ArrayAttr(
5309 {1, static_cast<int32_t>(indexInfo.size()), 1,
5310 static_cast<int32_t>(hasMask.succeeded())}));
5311 return parser.addTypeToList(vectorType, result.types);
5312}
5313
5314LogicalResult TransferReadOp::verify() {
5315 // Consistency of elemental types in source and vector.
5316 ShapedType shapedType = getShapedType();
5317 VectorType vectorType = getVectorType();
5318 VectorType maskType = getMaskType();
5319 auto paddingType = getPadding().getType();
5320 auto permutationMap = getPermutationMap();
5321 VectorType inferredMaskType =
5322 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5323 : VectorType();
5324 auto sourceElementType = shapedType.getElementType();
5325
5326 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
5327 return emitOpError("requires ") << shapedType.getRank() << " indices";
5328
5329 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5330 shapedType, vectorType, maskType,
5331 inferredMaskType, permutationMap, getInBounds())))
5332 return failure();
5333
5334 if (auto sourceVectorElementType =
5335 llvm::dyn_cast<VectorType>(sourceElementType)) {
5336 // Source has vector element type.
5337 // Check that 'sourceVectorElementType' and 'paddingType' types match.
5338 if (sourceVectorElementType != paddingType)
5339 return emitOpError(
5340 "requires source element type and padding type to match.");
5341
5342 } else {
5343 // Check that 'paddingType' is valid to store in a vector type.
5344 if (!VectorType::isValidElementType(paddingType))
5345 return emitOpError("requires valid padding vector elemental type");
5346
5347 // Check that padding type and vector element types match.
5348 if (paddingType != sourceElementType)
5349 return emitOpError(
5350 "requires formal padding and source of the same elemental type");
5351 }
5352
5353 return verifyPermutationMap(permutationMap,
5354 [&](Twine t) { return emitOpError(t); });
5355}
5356
5357// MaskableOpInterface methods.
5358
5359/// Returns the mask type expected by this operation. Mostly used for
5360/// verification purposes. It requires the operation to be vectorized."
5361Type TransferReadOp::getExpectedMaskType() {
5362 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5363}
5364
5365//===----------------------------------------------------------------------===//
5366// TransferReadOp: VectorTransferOpInterface methods.
5367//===----------------------------------------------------------------------===//
5368VectorType TransferReadOp::getVectorType() {
5369 return cast<VectorType>(getVector().getType());
5370}
5371
5372template <typename TransferOp>
5373static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
5374 // TODO: support more aggressive createOrFold on:
5375 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
5376 if (op.getShapedType().isDynamicDim(indicesIdx))
5377 return false;
5378 Value index = op.getIndices()[indicesIdx];
5379 std::optional<int64_t> cstOp = getConstantIntValue(index);
5380 if (!cstOp.has_value())
5381 return false;
5382
5383 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5384 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5385
5386 return cstOp.value() + vectorSize <= sourceSize;
5387}
5388
5389template <typename TransferOp>
5390static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
5391 // TODO: support 0-d corner case.
5392 // TODO: Be less conservative.
5393 if (op.getTransferRank() == 0)
5394 return failure();
5395 AffineMap permutationMap = op.getPermutationMap();
5396 bool changed = false;
5397 SmallVector<bool, 4> newInBounds;
5398 newInBounds.reserve(op.getTransferRank());
5399 // Idxs of non-bcast dims - used when analysing bcast dims.
5400 SmallVector<unsigned> nonBcastDims;
5401
5402 // 1. Process non-broadcast dims
5403 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5404 // 1.1. Already marked as in-bounds, nothing to see here.
5405 if (op.isDimInBounds(i)) {
5406 newInBounds.push_back(true);
5407 continue;
5408 }
5409 // 1.2. Currently out-of-bounds, check whether we can statically determine
5410 // it is inBounds.
5411 bool inBounds = false;
5412 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5413 if (dimExpr) {
5414 inBounds = isInBounds(op, /*resultIdx=*/i,
5415 /*indicesIdx=*/dimExpr.getPosition());
5416 nonBcastDims.push_back(i);
5417 }
5418
5419 newInBounds.push_back(inBounds);
5420 // We commit the pattern if it is "more inbounds".
5421 changed |= inBounds;
5422 }
5423
5424 // 2. Handle broadcast dims
5425 // If all non-broadcast dims are "in bounds", then all bcast dims should be
5426 // "in bounds" as well.
5427 bool allNonBcastDimsInBounds = llvm::all_of(
5428 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5429 if (allNonBcastDimsInBounds) {
5430 for (size_t idx : permutationMap.getBroadcastDims()) {
5431 changed |= !newInBounds[idx];
5432 newInBounds[idx] = true;
5433 }
5434 }
5435
5436 if (!changed)
5437 return failure();
5438 // OpBuilder is only used as a helper to build an I64ArrayAttr.
5439 OpBuilder b(op.getContext());
5440 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5441 return success();
5442}
5443
5444template <typename TransferOp>
5445static LogicalResult foldTransferFullMask(TransferOp op) {
5446 auto mask = op.getMask();
5447 if (!mask)
5448 return failure();
5449
5451 return failure();
5452
5453 op.getMaskMutable().clear();
5454 return success();
5455}
5456
5457/// When the vector type is `vector<1xT>`, the permutation map is irrelevant:
5458/// the single vector lane always has iteration offset 0, so the element is at
5459/// `indices` regardless of which source dimension the map points at. Replace
5460/// with the minor identity to unblock lowering to vector.load / vector.store.
5461template <typename TransferOp>
5462static LogicalResult foldSize1TransferPermutationMap(TransferOp op) {
5463 VectorType vecType = op.getVectorType();
5464 if (vecType.getRank() != 1 || vecType.getShape()[0] != 1 ||
5465 vecType.isScalable())
5466 return failure();
5467
5468 AffineMap map = op.getPermutationMap();
5469 if (map.isMinorIdentity())
5470 return failure();
5471
5472 int64_t srcRank = op.getShapedType().getRank();
5473 if (srcRank < 1)
5474 return failure();
5475
5476 AffineMap minorIdentity =
5477 AffineMap::getMinorIdentityMap(srcRank, 1, op.getContext());
5478 op.setPermutationMapAttr(AffineMapAttr::get(minorIdentity));
5479 return success();
5480}
5481
5482/// ```
5483/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5484/// : vector<1x4xf32>, tensor<4x4xf32>
5485/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5486/// : tensor<4x4xf32>, vector<1x4xf32>
5487/// ```
5488/// -> Folds into
5489/// ```
5490/// %v0
5491/// ```
5492static Value foldRAW(TransferReadOp readOp) {
5493 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5494 return {};
5495 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5496 while (defWrite) {
5497 if (checkSameValueRAW(defWrite, readOp))
5498 return defWrite.getVector();
5500 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5501 cast<VectorTransferOpInterface>(readOp.getOperation())))
5502 break;
5503 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5504 }
5505 return {};
5506}
5507
5508OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5509 if (Value vec = foldRAW(*this))
5510 return vec;
5511 /// transfer_read(memrefcast) -> transfer_read
5512 if (succeeded(foldTransferInBoundsAttribute(*this)))
5513 return getResult();
5514 if (succeeded(foldTransferFullMask(*this)))
5515 return getResult();
5516 if (succeeded(foldSize1TransferPermutationMap(*this)))
5517 return getResult();
5518 if (succeeded(memref::foldMemRefCast(*this)))
5519 return getResult();
5520 if (succeeded(tensor::foldTensorCast(*this)))
5521 return getResult();
5522 return OpFoldResult();
5523}
5524
5525std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5526 return llvm::to_vector<4>(getVectorType().getShape());
5527}
5528
5529void TransferReadOp::getEffects(
5530 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5531 &effects) {
5532 if (llvm::isa<MemRefType>(getShapedType()))
5533 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5534 SideEffects::DefaultResource::get());
5535}
5536
5537Speculation::Speculatability TransferReadOp::getSpeculatability() {
5538 if (hasPureTensorSemantics())
5541}
5542
5543/// Given a projected permutation, inverse an affine map, making the unused dims
5544/// 0 in the result.
5545static AffineMap inverseWithUnusedDims(AffineMap map) {
5546 assert(map.isProjectedPermutation() &&
5547 "expected a projected permutation map");
5548 SmallVector<AffineExpr> results(map.getNumInputs(),
5550 for (auto [idx, result] : llvm::enumerate(map.getResults())) {
5551 // We should only have dim exprs because this is a projected permutation.
5552 int64_t pos = cast<AffineDimExpr>(result).getPosition();
5553 results[pos] = getAffineDimExpr(idx, map.getContext());
5554 }
5555 return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
5556 results, map.getContext());
5557}
5558
5559namespace {
5560/// Store to load forwarding for transfer operations with permuation maps.
5561/// Even if the permutation maps are different we can still propagate the store
5562/// into the load if the size of the dimensions read and written match. Then we
5563/// can replace the transfer_read + transfer_write by vector.broadcast and
5564/// vector.transpose.
5565/// Example:
5566/// ```
5567/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5568/// {in_bounds = [true, true],
5569/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5570/// vector<4x1xf32>, tensor<4x4x4xf32>
5571/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5572/// {in_bounds = [true, true, true, true],
5573/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5574/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5575/// ```
5576/// To:
5577/// ```
5578/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
5579/// %r = vector.transpose %0, [3, 0, 2, 1] :
5580/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5581/// ```
5582struct TransferReadAfterWriteToBroadcast
5583 : public OpRewritePattern<TransferReadOp> {
5584 using Base::Base;
5585
5586 LogicalResult matchAndRewrite(TransferReadOp readOp,
5587 PatternRewriter &rewriter) const override {
5588 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5589 if (!defWrite)
5590 return failure();
5591 // Bail if we need an alias analysis.
5592 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5593 return failure();
5594 // Bail in the masked case (too complex atm and needed to properly account
5595 // for padding).
5596 if (readOp.getMask() || defWrite.getMask())
5597 return failure();
5598 // If indices are not the same a shift may be required, bail.
5599 if (readOp.getIndices() != defWrite.getIndices())
5600 return failure();
5601 // Bail if we need a bounds analysis.
5602 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5603 return failure();
5604 // TODO: If the written transfer chunk is a superset of the read transfer
5605 // chunk we could do an extract_strided_slice.
5606 if (readOp.getTransferChunkAccessed() !=
5607 defWrite.getTransferChunkAccessed())
5608 return failure();
5609 // WriteMap: tensor -> w_vec
5610 // ReadMap: tensor -> r_vec
5611 //
5612 // inv(WriteMap): w_vec -> tensor
5613 // inv(WriteMap) o ReadMap: w_vec -> r_vec
5614 AffineMap readMap = readOp.getPermutationMap();
5615 AffineMap writeMap = defWrite.getPermutationMap();
5616 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5617 AffineMap composedMap = readMap.compose(invWriteMap);
5618 // If there are any unused dims in the composedMap, we have to drop some
5619 // unit dims from the written vector before we can do transpose(broadcast).
5620 // TODO: Support this case.
5621 if (getUnusedDimsBitVector(composedMap).any())
5622 return failure();
5623 // readVec = transpose(broadcast(writeVec))
5624 //
5625 // Build a transpose permutation for the above transpose operation.
5626 //
5627 // Treat the composed map as having extra leading dimensions which are
5628 // the broadcasted dimensions, and treat the zeros as these new broadcasted
5629 // dimensions.
5630 SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
5631 int64_t numBroadcastedDims = broadcastedDims.size();
5632 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5633 invPerm.resize(composedMap.getNumResults());
5634 for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
5635 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
5636 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5637 invPerm[effectiveDim] = idx;
5638 }
5639 }
5640 // Applying the inverse permutation on the readVecTy will give us the
5641 // broadcast result type.
5642 VectorType readVecTy = readOp.getVectorType();
5643 SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
5644 auto broadcastedVecTy =
5645 VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
5646 readVecTy.getElementType(),
5647 applyPermutation(readVecTy.getScalableDims(), invPerm));
5648 // Build the transpose(broadcast) transformation.
5649 Value vec = defWrite.getVector();
5650 Location loc = readOp.getLoc();
5651 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5652 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
5653 return success();
5654 }
5655};
5656} // namespace
5657
5658void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5659 MLIRContext *context) {
5660 results.add<TransferReadAfterWriteToBroadcast>(context);
5661}
5662
5663FailureOr<std::optional<SmallVector<Value>>>
5664TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5665 if (!hasPureBufferSemantics())
5666 return failure();
5668 getResult());
5669}
5670
5671//===----------------------------------------------------------------------===//
5672// TransferWriteOp
5673//===----------------------------------------------------------------------===//
5674
5675/// 1. Builder with type inference.
5676void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5677 Value vector, Value dest, ValueRange indices,
5678 AffineMapAttr permutationMapAttr,
5679 /*optional*/ Value mask,
5680 /*optional*/ ArrayAttr inBoundsAttr) {
5681 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5682 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5683 mask, inBoundsAttr);
5684}
5685
5686/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5687void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5688 Value vector, Value dest, ValueRange indices,
5689 AffineMapAttr permutationMapAttr,
5690 /*optional*/ ArrayAttr inBoundsAttr) {
5691 build(builder, result, vector, dest, indices, permutationMapAttr,
5692 /*mask=*/Value(), inBoundsAttr);
5693}
5694
5695/// 3. Builder with type inference that sets an empty mask (variant without
5696/// attrs). If `permutationMap` is null, a minor identity map is used.
5697void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5698 Value vector, Value dest, ValueRange indices,
5699 AffineMap permutationMap,
5700 std::optional<ArrayRef<bool>> inBounds) {
5701 if (!permutationMap)
5702 permutationMap =
5703 getTransferMinorIdentityMap(llvm::cast<ShapedType>(dest.getType()),
5704 llvm::cast<VectorType>(vector.getType()));
5705 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5706 auto inBoundsAttr =
5707 (inBounds && !inBounds.value().empty())
5708 ? builder.getBoolArrayAttr(inBounds.value())
5709 : builder.getBoolArrayAttr(SmallVector<bool>(
5710 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5711 build(builder, result, vector, dest, indices, permutationMapAttr,
5712 /*mask=*/Value(), inBoundsAttr);
5713}
5714
5715/// 4. Builder with type inference that sets an empty mask and sets permutation
5716/// map to 'getMinorIdentityMap'.
5717void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5718 Value vector, Value dest, ValueRange indices,
5719 std::optional<ArrayRef<bool>> inBounds) {
5720 build(builder, result, vector, dest, indices, /*permutationMap=*/AffineMap(),
5721 inBounds);
5722}
5723
// Parses the custom assembly form of the op, in this order:
//   <vector> `,` <base> `[` <indices> `]` (`,` <mask>)? <attr-dict>?
//   `:` <vector-type> `,` <shaped-type>
// Missing `permutation_map` / `in_bounds` attributes are filled in with
// defaults, operands are resolved against the parsed types, and the operand
// segment sizes attribute is attached. A result type is only added when the
// shaped type is a ranked tensor.
5724ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5725 OperationState &result) {
5726 auto &builder = parser.getBuilder();
5727 SMLoc typesLoc;
5728 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5729 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5730 SmallVector<Type, 2> types;
5731 OpAsmParser::UnresolvedOperand maskInfo;
 // Mandatory operands: the vector, the base and the bracketed index list.
5732 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5733 parser.parseOperand(sourceInfo) ||
5734 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5735 return failure();
 // A trailing comma announces the optional mask operand.
5736 ParseResult hasMask = parser.parseOptionalComma();
5737 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5738 return failure();
 // Remember where the type list starts so type errors point at it.
5739 if (parser.parseOptionalAttrDict(result.attributes) ||
5740 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5741 return failure();
5742 if (types.size() != 2)
5743 return parser.emitError(typesLoc, "requires two types");
5744 auto indexType = builder.getIndexType();
 // First type: the vector being written. Second type: memref or ranked tensor.
5745 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5746 if (!vectorType)
5747 return parser.emitError(typesLoc, "requires vector type");
5748 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5749 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5750 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5751 auto permMapAttrName =
5752 TransferWriteOp::getPermutationMapAttrName(result.name);
5753 auto permMapAttr = result.attributes.get(permMapAttrName);
5754 AffineMap permMap;
 // Without an explicit permutation map, use the one computed by
 // getTransferMinorIdentityMap. This is rejected when the shaped type's rank is
 // smaller than the effective vector rank.
5755 if (!permMapAttr) {
5756 if (shapedType.getRank() <
5757 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5758 return parser.emitError(typesLoc,
5759 "expected a custom permutation_map when "
5760 "rank(source) != rank(destination)");
5761 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5762 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5763 } else {
5764 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5765 }
 // Without an explicit in_bounds attribute, default to `false` for every
 // result of the permutation map.
5766 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5767 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5768 if (!inBoundsAttr) {
5769 result.addAttribute(inBoundsAttrName,
5770 builder.getBoolArrayAttr(
5771 SmallVector<bool>(permMap.getNumResults(), false)));
5772 }
 // Resolve operands in order: vector, base, indices (all of index type).
5773 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5774 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5775 parser.resolveOperands(indexInfo, indexType, result.operands))
5776 return failure();
 // The mask, when present, is resolved last against a type inferred from the
 // vector type and the permutation map.
5777 if (hasMask.succeeded()) {
5778 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5779 return parser.emitError(
5780 maskInfo.location, "does not support masks with vector element type");
5781 if (vectorType.getRank() != permMap.getNumResults()) {
5782 return parser.emitError(typesLoc,
5783 "expected the same rank for the vector and the "
5784 "results of the permutation map");
5785 }
5786 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5787 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5788 return failure();
5789 }
 // Segment sizes: one vector, one base, N indices, and 0 or 1 mask.
5790 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5791 builder.getDenseI32ArrayAttr(
5792 {1, 1, static_cast<int32_t>(indexInfo.size()),
5793 static_cast<int32_t>(hasMask.succeeded())}));
 // For a ranked tensor base the result type is appended; parsing fails only if
 // that append fails. For any other base no result type is added.
5794 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5795 parser.addTypeToList(shapedType, result.types));
5796}
5797
5798void TransferWriteOp::print(OpAsmPrinter &p) {
5799 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5800 if (getMask())
5801 p << ", " << getMask();
5802 printTransferAttrs(p, *this);
5803 p << " : " << getVectorType() << ", " << getShapedType();
5804}
5805
5806LogicalResult TransferWriteOp::verify() {
5807 // Consistency of elemental types in shape and vector.
5808 ShapedType shapedType = getShapedType();
5809 VectorType vectorType = getVectorType();
5810 VectorType maskType = getMaskType();
5811 auto permutationMap = getPermutationMap();
5812 VectorType inferredMaskType =
5813 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5814 : VectorType();
5815
5816 if (llvm::size(getIndices()) != shapedType.getRank())
5817 return emitOpError("requires ") << shapedType.getRank() << " indices";
5818
5819 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5820 // as the semantics is unclear. This can be revisited later if necessary.
5821 if (hasBroadcastDim())
5822 return emitOpError("should not have broadcast dimensions");
5823
5824 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5825 shapedType, vectorType, maskType,
5826 inferredMaskType, permutationMap, getInBounds())))
5827 return failure();
5828
5829 return verifyPermutationMap(permutationMap,
5830 [&](Twine t) { return emitOpError(t); });
5831}
5832
5833//===----------------------------------------------------------------------===//
5834// TransferWriteOp: MaskableOpInterface methods.
5835//===----------------------------------------------------------------------===//
5836
5837/// Returns the mask type expected by this operation. Mostly used for
5838/// verification purposes.
5839Type TransferWriteOp::getExpectedMaskType() {
5840 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5841}
5842
5843//===----------------------------------------------------------------------===//
5844// TransferWriteOp: VectorTransferOpInterface methods.
5845//===----------------------------------------------------------------------===//
5846Value TransferWriteOp::getVector() { return getOperand(0); }
5847VectorType TransferWriteOp::getVectorType() {
5848 return cast<VectorType>(getValueToStore().getType());
5849}
5850
5851//===----------------------------------------------------------------------===//
5852// TransferWriteOp: fold methods.
5853//===----------------------------------------------------------------------===//
5854/// Fold:
5855/// ```
5856/// %t1 = ...
5857/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5858/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5859/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5860/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5861/// ```
5862///
5863/// into:
5864///
5865/// ```
5866/// %t0
5867/// ```
5868///
5869/// The producer of t1 may or may not be DCE'd depending on whether it is a
5870/// block argument or has side effects.
5871static LogicalResult foldReadInitWrite(TransferWriteOp write,
5872 ArrayRef<Attribute>,
5873 SmallVectorImpl<OpFoldResult> &results) {
5874 // TODO: support 0-d corner case.
5875 if (write.getTransferRank() == 0)
5876 return failure();
5877 auto rankedTensorType =
5878 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5879 // If not operating on tensors, bail.
5880 if (!rankedTensorType)
5881 return failure();
5882 // If no read, bail.
5883 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5884 if (!read)
5885 return failure();
5886 // TODO: support 0-d corner case.
5887 if (read.getTransferRank() == 0)
5888 return failure();
5889 // For now, only accept minor identity. Future: composition is minor identity.
5890 if (!read.getPermutationMap().isMinorIdentity() ||
5891 !write.getPermutationMap().isMinorIdentity())
5892 return failure();
5893 // Bail on mismatching ranks.
5894 if (read.getTransferRank() != write.getTransferRank())
5895 return failure();
5896 // Bail on potential out-of-bounds accesses.
5897 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5898 return failure();
5899 // Masked transfers have padding/select semantics and are not identity folds.
5900 if (read.getMask() || write.getMask())
5901 return failure();
5902 // Tensor types must be the same.
5903 if (read.getBase().getType() != rankedTensorType)
5904 return failure();
5905 // Vector types must be the same.
5906 if (read.getVectorType() != write.getVectorType())
5907 return failure();
5908 // Vector and Tensor shapes must match.
5909 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5910 return failure();
5911 // If any index is nonzero.
5912 auto isNotConstantZero = [](Value v) {
5913 auto cstOp = getConstantIntValue(v);
5914 return !cstOp.has_value() || cstOp.value() != 0;
5915 };
5916 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5917 llvm::any_of(write.getIndices(), isNotConstantZero))
5918 return failure();
5919 // Success.
5920 results.push_back(read.getBase());
5921 return success();
5922}
5923
5924static bool checkSameValueWAR(vector::TransferReadOp read,
5925 vector::TransferWriteOp write) {
5926 return read.getBase() == write.getBase() &&
5927 read.getIndices() == write.getIndices() &&
5928 read.getPermutationMap() == write.getPermutationMap() &&
5929 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5930 !write.getMask();
5931}
5932/// Fold transfer_write write after read:
5933/// ```
5934/// %t0 = ...
5935/// %v = vector.transfer_read %t0[%c0...] :
5936/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5937/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5938/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5939/// ```
5940///
5941/// into:
5942///
5943/// ```
5944/// %t0
5945/// ```
5946static LogicalResult foldWAR(TransferWriteOp write,
5947 SmallVectorImpl<OpFoldResult> &results) {
5948 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5949 return failure();
5950 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5951 if (!read)
5952 return failure();
5953
5954 if (!checkSameValueWAR(read, write))
5955 return failure();
5956 results.push_back(read.getBase());
5957 return success();
5958}
5959
5960LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5961 SmallVectorImpl<OpFoldResult> &results) {
5962 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5963 return success();
5964 if (succeeded(foldWAR(*this, results)))
5965 return success();
5966 if (succeeded(foldTransferInBoundsAttribute(*this)))
5967 return success();
5968 if (succeeded(foldTransferFullMask(*this)))
5969 return success();
5970 if (succeeded(foldSize1TransferPermutationMap(*this)))
5971 return success();
5972 return memref::foldMemRefCast(*this);
5973}
5974
5975//===----------------------------------------------------------------------===//
5976// TransferWriteOp: other methods.
5977//===----------------------------------------------------------------------===//
5978std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5979 return llvm::to_vector<4>(getVectorType().getShape());
5980}
5981
5982void TransferWriteOp::getEffects(
5983 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5984 &effects) {
5985 if (llvm::isa<MemRefType>(getShapedType()))
5986 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5987 SideEffects::DefaultResource::get());
5988}
5989
// Reports the speculatability of the op, branching on
// `hasPureTensorSemantics()`.
// NOTE(review): this listing jumps from line 5991 to 5994, so the value
// returned in each case is not visible here; confirm against the full source.
5990Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5991 if (hasPureTensorSemantics())
5994}
5995
5996namespace {
5997/// Remove dead transfer write from the SSA chain so that it can be eliminated by
5998/// DCE
5999/// ```
6000/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
6001/// : vector<1x4xf32>, tensor<4x4xf32>
6002/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
6003/// : vector<1x4xf32>, tensor<4x4xf32>
6004/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
6005/// : vector<1x4xf32>, tensor<4x4xf32>
6006/// ```
6007///
6008/// into:
6009///
6010/// ```
6011/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
6012/// : vector<1x4xf32>, tensor<4x4xf32>
6013/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
6014/// : vector<1x4xf32>, tensor<4x4xf32>
6015/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
6016/// : vector<1x4xf32>, tensor<4x4xf32>
6017/// ```
6018///
6019/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
6020/// any other uses.
6021class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
6022public:
6023 using Base::Base;
6024 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
6025 PatternRewriter &rewriter) const override {
 // Only writes into ranked tensors form an SSA chain that can be rewired.
6026 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
6027 return failure();
 // `writeToModify` is the write whose base gets rewired when an overwritten
 // write is found; it trails `defWrite` by one step along the chain.
6028 vector::TransferWriteOp writeToModify = writeOp;
6029
6030 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
6031 while (defWrite) {
 // `defWrite` matches `writeOp` per checkSameValueWAW: bypass it by pointing
 // `writeToModify` at `defWrite`'s own base.
6032 if (checkSameValueWAW(writeOp, defWrite)) {
6033 rewriter.modifyOpInPlace(writeToModify, [&]() {
6034 writeToModify.getBaseMutable().assign(defWrite.getBase());
6035 });
6036 return success();
6037 }
 // NOTE(review): listing line 6038 (the start of this condition) is missing
 // from this view; only its two interface-cast arguments and the `break` are
 // visible. Confirm the predicate against the full source.
6039 cast<VectorTransferOpInterface>(defWrite.getOperation()),
6040 cast<VectorTransferOpInterface>(writeOp.getOperation())))
6041 break;
6042 // If the previous write op doesn't have any other use we can safely look
6043 // at the previous store to see if it can be removed.
6044 if (!defWrite->hasOneUse())
6045 break;
6046 writeToModify = defWrite;
6047 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
6048 }
6049 return failure();
6050 }
6051};
6052
6053/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
6054/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
6055/// overwritten and inserted into another tensor. After this rewrite, the
6056/// operations bufferize in-place since all of them work on the same slice.
6057///
6058/// For example:
6059/// ```mlir
6060/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
6061/// : vector<8x16xf32>, tensor<8x16xf32>
6062/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
6063/// : tensor<8x16xf32> to tensor<?x?xf32>
6064/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6065/// : tensor<?x?xf32> into tensor<27x37xf32>
6066/// ```
6067/// folds to
6068/// ```mlir
6069/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6070/// : tensor<27x37xf32> to tensor<?x?xf32>
6071/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
6072/// : vector<8x16xf32>, tensor<?x?xf32>
6073/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6074/// : tensor<?x?xf32> into tensor<27x37xf32>
6075/// ```
6076struct SwapExtractSliceOfTransferWrite
6077 : public OpRewritePattern<tensor::InsertSliceOp> {
6078public:
6079 using Base::Base;
6080
6081 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
6082 PatternRewriter &rewriter) const override {
 // Match the chain insert_slice(extract_slice(transfer_write)), anchored on
 // the insert_slice. Both slice ops must have unit strides, and the
 // extract_slice and the transfer_write must each have exactly one use.
6083 if (!insertOp.hasUnitStride())
6084 return failure();
6085 auto extractOp =
6086 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
6087 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
6088 return failure();
6089 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
6090 if (!transferOp || !transferOp->hasOneUse())
6091 return failure();
6092
6093 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
6094 // rank-reducing.
6095 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
6096 return rewriter.notifyMatchFailure(insertOp,
6097 "use-def chain is rank-reducing");
6098 }
6099
6100 // Fail if tensor::ExtractSliceOp has non-zero offset.
6101 if (!extractOp.hasZeroOffset()) {
6102 return rewriter.notifyMatchFailure(insertOp,
6103 "ExtractSliceOp has non-zero offset");
6104 }
6105
6106 // Fail if vector::TransferWriteOp has non-zero offset.
6107 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
6108 return getConstantIntValue(value) == static_cast<int64_t>(0);
6109 })) {
6110 return rewriter.notifyMatchFailure(insertOp,
6111 "TranferWriteOp has non-zero offset");
6112 }
6113
6114 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
 // First compare the number of sizes, so that the element-wise comparison
 // below (zip_equal) runs over ranges of equal length.
6115 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
6116 return rewriter.notifyMatchFailure(
6117 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
6118 }
6119
6120 for (auto [insertSize, extractSize] :
6121 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
6122 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
6123 return rewriter.notifyMatchFailure(
6124 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
6125 }
6126 }
6127
6128 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
 // That is: the write is masked, or the vector shape differs from the
 // destination shape after applying the permutation map.
6129 assert(transferOp.getVectorType().hasStaticShape() &&
6130 "expected vector to have a static shape");
6131 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
6132 SmallVector<int64_t> resultShape = applyPermutationMap(
6133 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
6134 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
6135 return rewriter.notifyMatchFailure(
6136 insertOp, "TransferWriteOp may not write the full tensor.");
6137 }
6138
6139 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
6140 // Set all in_bounds to false and let the folder infer them.
6141 SmallVector<bool> newInBounds(vectorShape.size(), false);
 // The new extract_slice reads from the insert_slice destination, reusing
 // the insert_slice offsets, sizes and strides.
6142 auto newExtractOp = tensor::ExtractSliceOp::create(
6143 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
6144 insertOp.getDest(), insertOp.getMixedOffsets(),
6145 insertOp.getMixedSizes(), insertOp.getMixedStrides());
 // The new transfer_write stores the same vector, with the same indices
 // and permutation map, into the freshly extracted slice.
6146 auto newTransferWriteOp = TransferWriteOp::create(
6147 rewriter, transferOp.getLoc(), transferOp.getVector(),
6148 newExtractOp.getResult(), transferOp.getIndices(),
6149 transferOp.getPermutationMapAttr(),
6150 rewriter.getBoolArrayAttr(newInBounds));
 // Rewire the insert_slice source to the new write. The old ops are not
 // erased here.
6151 rewriter.modifyOpInPlace(insertOp, [&]() {
6152 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
6153 });
6154 return success();
6155 }
6156};
6157
6158} // namespace
6159
6160void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
6161 MLIRContext *context) {
6162 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
6163}
6164
// Returns failure unless the op has pure buffer semantics.
// NOTE(review): listing line 6169 (the start of the return statement) is
// missing from this view; only its trailing `ValueRange()` argument is visible.
6165FailureOr<std::optional<SmallVector<Value>>>
6166TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
6167 if (!hasPureBufferSemantics())
6168 return failure();
6170 ValueRange());
6171}
6172
6173//===----------------------------------------------------------------------===//
6174// LoadOp
6175//===----------------------------------------------------------------------===//
6176
6177static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
6178 VectorType vecTy,
6179 MemRefType memRefTy) {
6180 // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
6181 // need any strides limitations.
6182 if (!vecTy.isScalable() &&
6183 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
6184 return success();
6185
6186 if (!memRefTy.isLastDimUnitStride())
6187 return op->emitOpError("most minor memref dim must have unit stride");
6188 return success();
6189}
6190
6191LogicalResult vector::LoadOp::verify() {
6192 VectorType resVecTy = getVectorType();
6193 MemRefType memRefTy = getMemRefType();
6194
6195 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
6196 return failure();
6197
6198 // Negative strides are not supported on vector.load.
6199 if (memref::hasNegativeStaticStride(memRefTy))
6200 return emitOpError("memref strides must be non-negative");
6201
6202 if (memRefTy.getRank() < resVecTy.getRank())
6203 return emitOpError(
6204 "destination memref has lower rank than the result vector");
6205
6206 // Checks for vector memrefs.
6207 Type memElemTy = memRefTy.getElementType();
6208 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6209 if (memVecTy != resVecTy)
6210 return emitOpError("base memref and result vector types should match");
6211 memElemTy = memVecTy.getElementType();
6212 }
6213
6214 if (resVecTy.getElementType() != memElemTy)
6215 return emitOpError("base and result element types should match");
6216 if (llvm::size(getIndices()) != memRefTy.getRank())
6217 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6218 return success();
6219}
6220
6221OpFoldResult LoadOp::fold(FoldAdaptor) {
6222 if (succeeded(memref::foldMemRefCast(*this)))
6223 return getResult();
6224 return OpFoldResult();
6225}
6226
6227std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6228 return llvm::to_vector<4>(getVectorType().getShape());
6229}
6230
// NOTE(review): listing line 6233 (the start of the return statement) is
// missing from this view; only its trailing `getResult()` argument is visible.
6231FailureOr<std::optional<SmallVector<Value>>>
6232LoadOp::bubbleDownCasts(OpBuilder &builder) {
6234 getResult());
6235}
6236
6237//===----------------------------------------------------------------------===//
6238// StoreOp
6239//===----------------------------------------------------------------------===//
6240
6241LogicalResult vector::StoreOp::verify() {
6242 VectorType valueVecTy = getVectorType();
6243 MemRefType memRefTy = getMemRefType();
6244
6245 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
6246 return failure();
6247
6248 // Negative strides are not supported on vector.store.
6249 if (memref::hasNegativeStaticStride(memRefTy))
6250 return emitOpError("memref strides must be non-negative");
6251
6252 if (memRefTy.getRank() < valueVecTy.getRank())
6253 return emitOpError("source memref has lower rank than the vector to store");
6254
6255 // Checks for vector memrefs.
6256 Type memElemTy = memRefTy.getElementType();
6257 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6258 if (memVecTy != valueVecTy)
6259 return emitOpError(
6260 "base memref and valueToStore vector types should match");
6261 memElemTy = memVecTy.getElementType();
6262 }
6263
6264 if (valueVecTy.getElementType() != memElemTy)
6265 return emitOpError("base and valueToStore element type should match");
6266 if (llvm::size(getIndices()) != memRefTy.getRank())
6267 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6268 return success();
6269}
6270
6271LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6272 SmallVectorImpl<OpFoldResult> &results) {
6273 return memref::foldMemRefCast(*this);
6274}
6275
6276std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6277 return llvm::to_vector<4>(getVectorType().getShape());
6278}
6279
// NOTE(review): listing line 6282 (the start of the return statement) is
// missing from this view; only its trailing `ValueRange()` argument is visible.
6280FailureOr<std::optional<SmallVector<Value>>>
6281StoreOp::bubbleDownCasts(OpBuilder &builder) {
6283 ValueRange());
6284}
6285
6286//===----------------------------------------------------------------------===//
6287// MaskedLoadOp
6288//===----------------------------------------------------------------------===//
6289
6290LogicalResult MaskedLoadOp::verify() {
6291 VectorType maskVType = getMaskVectorType();
6292 VectorType passVType = getPassThruVectorType();
6293 VectorType resVType = getVectorType();
6294 MemRefType memType = getMemRefType();
6295
6296 if (failed(
6297 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6298 return failure();
6299 if (llvm::size(getIndices()) != memType.getRank())
6300 return emitOpError("requires ") << memType.getRank() << " indices";
6301 if (resVType.getShape() != maskVType.getShape())
6302 return emitOpError("expected result shape to match mask shape");
6303 if (resVType != passVType)
6304 return emitOpError("expected pass_thru of same type as result type");
6305 return success();
6306}
6307
6308namespace {
// Canonicalization pattern for MaskedLoadOp that dispatches on the result of
// getMaskFormat for the op's mask. The visible branches replace the op with a
// vector::LoadOp, replace it with its pass-thru value, or fail to match.
// NOTE(review): the `case` labels (listing lines 6315, 6319, 6322) are missing
// from this view, so which mask format selects which branch cannot be
// confirmed here.
6309class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
6310public:
6311 using Base::Base;
6312 LogicalResult matchAndRewrite(MaskedLoadOp load,
6313 PatternRewriter &rewriter) const override {
6314 switch (getMaskFormat(load.getMask())) {
6316 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6317 load, load.getType(), load.getBase(), load.getIndices());
6318 return success();
6320 rewriter.replaceOp(load, load.getPassThru());
6321 return success();
6323 return failure();
6324 }
6325 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
6326 }
6327};
6328} // namespace
6329
6330void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6331 MLIRContext *context) {
6332 results.add<MaskedLoadFolder>(context);
6333}
6334
6335OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6336 if (succeeded(memref::foldMemRefCast(*this)))
6337 return getResult();
6338 return OpFoldResult();
6339}
6340
// NOTE(review): listing line 6343 (the start of the return statement) is
// missing from this view; only its trailing `getResult()` argument is visible.
6341FailureOr<std::optional<SmallVector<Value>>>
6342MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6344 getResult());
6345}
6346
6347//===----------------------------------------------------------------------===//
6348// MaskedStoreOp
6349//===----------------------------------------------------------------------===//
6350
6351LogicalResult MaskedStoreOp::verify() {
6352 VectorType maskVType = getMaskVectorType();
6353 VectorType valueVType = getVectorType();
6354 MemRefType memType = getMemRefType();
6355
6356 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6357 "valueToStore")))
6358 return failure();
6359 if (llvm::size(getIndices()) != memType.getRank())
6360 return emitOpError("requires ") << memType.getRank() << " indices";
6361 if (valueVType.getShape() != maskVType.getShape())
6362 return emitOpError("expected valueToStore shape to match mask shape");
6363 return success();
6364}
6365
6366namespace {
// Canonicalization pattern for MaskedStoreOp that dispatches on the result of
// getMaskFormat for the op's mask. The visible branches replace the op with a
// vector::StoreOp, erase it, or fail to match.
// NOTE(review): the `case` labels (listing lines 6373, 6377, 6380) are missing
// from this view, so which mask format selects which branch cannot be
// confirmed here.
6367class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
6368public:
6369 using Base::Base;
6370 LogicalResult matchAndRewrite(MaskedStoreOp store,
6371 PatternRewriter &rewriter) const override {
6372 switch (getMaskFormat(store.getMask())) {
6374 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6375 store, store.getValueToStore(), store.getBase(), store.getIndices());
6376 return success();
6378 rewriter.eraseOp(store);
6379 return success();
6381 return failure();
6382 }
6383 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
6384 }
6385};
6386} // namespace
6387
6388void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6389 MLIRContext *context) {
6390 results.add<MaskedStoreFolder>(context);
6391}
6392
6393LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6394 SmallVectorImpl<OpFoldResult> &results) {
6395 return memref::foldMemRefCast(*this);
6396}
6397
// NOTE(review): listing line 6400 (the start of the return statement) is
// missing from this view; only its trailing `ValueRange()` argument is visible.
6398FailureOr<std::optional<SmallVector<Value>>>
6399MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6401 ValueRange());
6402}
6403
6404//===----------------------------------------------------------------------===//
6405// GatherOp
6406//===----------------------------------------------------------------------===//
6407
6408LogicalResult GatherOp::verify() {
6409 VectorType indVType = getIndexVectorType();
6410 VectorType maskVType = getMaskVectorType();
6411 VectorType resVType = getVectorType();
6412 ShapedType baseType = getBaseType();
6413
6414 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6415 return emitOpError("requires base to be a memref or ranked tensor type");
6416
6417 if (failed(
6418 verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
6419 return failure();
6420 if (llvm::size(getOffsets()) != baseType.getRank())
6421 return emitOpError("requires ") << baseType.getRank() << " indices";
6422 if (resVType.getShape() != indVType.getShape())
6423 return emitOpError("expected result dim to match indices dim");
6424 if (resVType.getShape() != maskVType.getShape())
6425 return emitOpError("expected result dim to match mask dim");
6426 if (resVType != getPassThruVectorType())
6427 return emitOpError("expected pass_thru of same type as result type");
6428 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6429 return emitOpError(
6430 "alignment is only supported for memref bases, not tensor bases");
6431 }
6432 return success();
6433}
6434
6435// MaskableOpInterface methods.
6436
6437/// Returns the mask type expected by this operation. Mostly used for
6438/// verification purposes. It requires the operation to be vectorized."
6439Type GatherOp::getExpectedMaskType() {
6440 auto vecType = this->getIndexVectorType();
6441 return VectorType::get(vecType.getShape(),
6442 IntegerType::get(vecType.getContext(), /*width=*/1),
6443 vecType.getScalableDims());
6444}
6445
6446std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6447 return llvm::to_vector<4>(getVectorType().getShape());
6448}
6449
6450/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
6451static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6452 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6453 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6454 return failure();
6455
6456 if (indexVec.getDefiningOp<StepOp>())
6457 return success();
6458
6459 DenseIntElementsAttr elements;
6460 if (!matchPattern(indexVec, m_Constant(&elements)))
6461 return failure();
6462
6463 return success(
6464 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6465}
6466
6467namespace {
// Canonicalization pattern for GatherOp that dispatches on the result of
// getMaskFormat for the op's mask. The visible branches fail to match (there
// is no unmasked equivalent), replace the op with its pass-thru value, or fail
// to match.
// NOTE(review): the `case` labels (listing lines 6474, 6476, 6479) are missing
// from this view, so which mask format selects which branch cannot be
// confirmed here.
6468class GatherFolder final : public OpRewritePattern<GatherOp> {
6469public:
6470 using Base::Base;
6471 LogicalResult matchAndRewrite(GatherOp gather,
6472 PatternRewriter &rewriter) const override {
6473 switch (getMaskFormat(gather.getMask())) {
6475 return failure(); // no unmasked equivalent
6477 rewriter.replaceOp(gather, gather.getPassThru());
6478 return success();
6480 return failure();
6481 }
6482 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6483 }
6484};
6485
6486/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6487/// maskedload. Only 1D fixed vectors are supported for now.
6488class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6489public:
6490 using Base::Base;
6491 LogicalResult matchAndRewrite(GatherOp op,
6492 PatternRewriter &rewriter) const override {
6493 if (!isa<MemRefType>(op.getBase().getType()))
6494 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6495
6496 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6497 return failure();
6498
6499 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6500 op.getOffsets(), op.getMask(),
6501 op.getPassThru());
6502 return success();
6503 }
6504};
6505} // namespace
6506
6507void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6508 MLIRContext *context) {
6509 results.add<GatherFolder, FoldContiguousGather>(context);
6510}
6511
// NOTE(review): listing line 6514 (the start of the return statement) is
// missing from this view; only its trailing `getResult()` argument is visible.
6512FailureOr<std::optional<SmallVector<Value>>>
6513GatherOp::bubbleDownCasts(OpBuilder &builder) {
6515 getResult());
6516}
6517
6518//===----------------------------------------------------------------------===//
6519// ScatterOp
6520//===----------------------------------------------------------------------===//
6521
6522LogicalResult ScatterOp::verify() {
6523 VectorType indVType = getIndexVectorType();
6524 VectorType maskVType = getMaskVectorType();
6525 VectorType valueVType = getVectorType();
6526 ShapedType baseType = getBaseType();
6527
6528 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6529 return emitOpError("requires base to be a memref or ranked tensor type");
6530
6531 if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
6532 "valueToStore")))
6533 return failure();
6534 if (llvm::size(getOffsets()) != baseType.getRank())
6535 return emitOpError("requires ") << baseType.getRank() << " indices";
6536 if (valueVType.getShape() != indVType.getShape())
6537 return emitOpError("expected valueToStore dim to match indices dim");
6538 if (valueVType.getShape() != maskVType.getShape())
6539 return emitOpError("expected valueToStore dim to match mask dim");
6540 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6541 return emitOpError(
6542 "alignment is only supported for memref bases, not tensor bases");
6543 }
6544 return success();
6545}
6546namespace {
// Canonicalization pattern for ScatterOp that dispatches on the result of
// getMaskFormat for the op's mask, after bailing out on bases that are neither
// memref nor ranked tensor.
// NOTE(review): the `case` labels (listing lines 6561, 6563, 6569) are missing
// from this view, so which mask format selects which branch cannot be
// confirmed here.
6547class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6548public:
6549 using Base::Base;
6550 LogicalResult matchAndRewrite(ScatterOp scatter,
6551 PatternRewriter &rewriter) const override {
6552 ShapedType baseType = scatter.getBaseType();
6553 bool isMemRef = isa<MemRefType>(baseType);
6554 if (!isMemRef && !isa<RankedTensorType>(baseType))
6555 return failure();
6556
6557 // Memrefs have no result, so an all-false mask can simply erase the op.
6558 // Tensors carry the updated value, so we must replace uses with the
6559 // original base tensor instead of erasing.
6560 switch (getMaskFormat(scatter.getMask())) {
6562 return failure(); // no unmasked equivalent
6564 if (isMemRef)
6565 rewriter.eraseOp(scatter);
6566 else
6567 rewriter.replaceOp(scatter, scatter.getBase());
6568 return success();
6570 return failure();
6571 }
6572 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6573 }
6574};
6575
6576/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6577/// maskedstore. Only 1D fixed vectors are supported for now.
6578class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6579public:
6580 using Base::Base;
6581 LogicalResult matchAndRewrite(ScatterOp op,
6582 PatternRewriter &rewriter) const override {
6583 // Fold only for memrefs: the replacement uses maskedstore, which does not
6584 // support tensor bases. Tensor cases intentionally bail out.
6585 if (!isa<MemRefType>(op.getBase().getType()))
6586 return failure();
6587
6588 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6589 return failure();
6590
6591 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6592 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6593 return success();
6594 }
6595};
6596} // namespace
6597
6598void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6599 MLIRContext *context) {
6600 results.add<ScatterFolder, FoldContiguousScatter>(context);
6601}
6602
6603FailureOr<std::optional<SmallVector<Value>>>
6604ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6606 ValueRange());
6607}
6608
6609//===----------------------------------------------------------------------===//
6610// ExpandLoadOp
6611//===----------------------------------------------------------------------===//
6612
6613LogicalResult ExpandLoadOp::verify() {
6614 VectorType maskVType = getMaskVectorType();
6615 VectorType passVType = getPassThruVectorType();
6616 VectorType resVType = getVectorType();
6617 MemRefType memType = getMemRefType();
6618
6619 if (failed(
6620 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6621 return failure();
6622 if (llvm::size(getIndices()) != memType.getRank())
6623 return emitOpError("requires ") << memType.getRank() << " indices";
6624 if (resVType.getShape() != maskVType.getShape())
6625 return emitOpError("expected result shape to match mask shape");
6626 if (resVType != passVType)
6627 return emitOpError("expected pass_thru of same type as result type");
6628 return success();
6629}
6630
6631namespace {
6632class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6633public:
6634 using Base::Base;
6635 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6636 PatternRewriter &rewriter) const override {
6637 switch (getMaskFormat(expand.getMask())) {
6639 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6640 expand, expand.getType(), expand.getBase(), expand.getIndices());
6641 return success();
6643 rewriter.replaceOp(expand, expand.getPassThru());
6644 return success();
6646 return failure();
6647 }
6648 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6649 }
6650};
6651} // namespace
6652
6653void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6654 MLIRContext *context) {
6655 results.add<ExpandLoadFolder>(context);
6656}
6657
6658FailureOr<std::optional<SmallVector<Value>>>
6659ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6661 getResult());
6662}
6663
6664//===----------------------------------------------------------------------===//
6665// CompressStoreOp
6666//===----------------------------------------------------------------------===//
6667
6668LogicalResult CompressStoreOp::verify() {
6669 VectorType maskVType = getMaskVectorType();
6670 VectorType valueVType = getVectorType();
6671 MemRefType memType = getMemRefType();
6672
6673 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6674 "valueToStore")))
6675 return failure();
6676 if (llvm::size(getIndices()) != memType.getRank())
6677 return emitOpError("requires ") << memType.getRank() << " indices";
6678 if (valueVType.getShape() != maskVType.getShape())
6679 return emitOpError("expected valueToStore shape to match mask shape");
6680 return success();
6681}
6682
6683namespace {
6684class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6685public:
6686 using Base::Base;
6687 LogicalResult matchAndRewrite(CompressStoreOp compress,
6688 PatternRewriter &rewriter) const override {
6689 switch (getMaskFormat(compress.getMask())) {
6691 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6692 compress, compress.getValueToStore(), compress.getBase(),
6693 compress.getIndices());
6694 return success();
6696 rewriter.eraseOp(compress);
6697 return success();
6699 return failure();
6700 }
6701 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6702 }
6703};
6704} // namespace
6705
6706void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6707 MLIRContext *context) {
6708 results.add<CompressStoreFolder>(context);
6709}
6710
6711FailureOr<std::optional<SmallVector<Value>>>
6712CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6714 ValueRange());
6715}
6716
6717//===----------------------------------------------------------------------===//
6718// ShapeCastOp
6719//===----------------------------------------------------------------------===//
6720
6721void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6722 SetIntRangeFn setResultRanges) {
6723 setResultRanges(getResult(), argRanges.front());
6724}
6725
6726std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6727 return llvm::to_vector<4>(getResultVectorType().getShape());
6728}
6729
6730LogicalResult ShapeCastOp::verify() {
6731
6732 VectorType sourceType = getSourceVectorType();
6733 VectorType resultType = getResultVectorType();
6734
6735 // Check that element type is preserved
6736 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
6737 "result")))
6738 return failure();
6739
6740 // Check that number of elements is preserved
6741 int64_t sourceNElms = sourceType.getNumElements();
6742 int64_t resultNElms = resultType.getNumElements();
6743 if (sourceNElms != resultNElms) {
6744 return emitOpError() << "has different number of elements at source ("
6745 << sourceNElms << ") and result (" << resultNElms
6746 << ")";
6747 }
6748
6749 // Check that (non-)scalability is preserved
6750 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6751 int64_t resultNScalableDims = resultType.getNumScalableDims();
6752 if (sourceNScalableDims != resultNScalableDims)
6753 return emitOpError() << "has different number of scalable dims at source ("
6754 << sourceNScalableDims << ") and result ("
6755 << resultNScalableDims << ")";
6756
6757 return success();
6758}
6759
6760/// Return true if `transpose` does not permute a pair of non-unit dims.
6761/// By `order preserving` we mean that the flattened versions of the input and
6762/// output vectors are (numerically) identical. In other words `transpose` is
6763/// effectively a shape cast.
6764static bool isOrderPreserving(TransposeOp transpose) {
6765 ArrayRef<int64_t> permutation = transpose.getPermutation();
6766 VectorType sourceType = transpose.getSourceVectorType();
6767 ArrayRef<int64_t> inShape = sourceType.getShape();
6768 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6769 auto isNonScalableUnitDim = [&](int64_t dim) {
6770 return inShape[dim] == 1 && !inDimIsScalable[dim];
6771 };
6772 int64_t current = 0;
6773 for (auto p : permutation) {
6774 if (!isNonScalableUnitDim(p)) {
6775 if (p < current) {
6776 return false;
6777 }
6778 current = p;
6779 }
6780 }
6781 return true;
6782}
6783
6784OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6785
6786 VectorType resultType = getType();
6787
6788 // No-op shape cast.
6789 if (getSource().getType() == resultType)
6790 return getSource();
6791
6792 // shape_cast(shape_cast(x)) -> shape_cast(x)
6793 if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6794 setOperand(precedingShapeCast.getSource());
6795 return getResult();
6796 }
6797
6798 // shape_cast(transpose(x)) -> shape_cast(x)
6799 if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6800 if (isOrderPreserving(transpose)) {
6801 setOperand(transpose.getVector());
6802 return getResult();
6803 }
6804 return {};
6805 }
6806
6807 // Y = shape_cast(broadcast(X))
6808 // -> X, if X and Y have same type
6809 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6810 if (bcastOp.getSourceType() == resultType)
6811 return bcastOp.getSource();
6812 }
6813
6814 // shape_cast(constant) -> constant
6815 if (auto denseAttr =
6816 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6817 return denseAttr.reshape(getType());
6818
6819 // shape_cast(poison) -> poison
6820 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
6821 return ub::PoisonAttr::get(getContext());
6822
6823 return {};
6824}
6825
6826namespace {
6827
6828/// Helper function that computes a new vector type based on the input vector
6829/// type by removing the trailing one dims:
6830///
6831/// vector<4x1x1xi1> --> vector<4x1xi1>
6832///
6833static VectorType trimTrailingOneDims(VectorType oldType) {
6834 ArrayRef<int64_t> oldShape = oldType.getShape();
6835 ArrayRef<int64_t> newShape = oldShape;
6836
6837 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6838 ArrayRef<bool> newScalableDims = oldScalableDims;
6839
6840 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6841 newShape = newShape.drop_back(1);
6842 newScalableDims = newScalableDims.drop_back(1);
6843 }
6844
6845 // Make sure we have at least 1 dimension.
6846 // TODO: Add support for 0-D vectors.
6847 if (newShape.empty()) {
6848 newShape = oldShape.take_back();
6849 newScalableDims = oldScalableDims.take_back();
6850 }
6851
6852 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6853}
6854
6855/// Folds qualifying shape_cast(create_mask) into a new create_mask
6856///
6857/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6858/// dimension. If the input vector comes from `vector.create_mask` for which
6859/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6860/// to fold shape_cast into create_mask.
6861///
6862/// BEFORE:
6863/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6864/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6865/// AFTER:
6866/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6867class ShapeCastCreateMaskFolderTrailingOneDim final
6868 : public OpRewritePattern<ShapeCastOp> {
6869public:
6870 using Base::Base;
6871
6872 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6873 PatternRewriter &rewriter) const override {
6874 Value shapeOpSrc = shapeOp->getOperand(0);
6875 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6876 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6877 if (!createMaskOp && !constantMaskOp)
6878 return failure();
6879
6880 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6881 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6882
6883 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6884 if (newVecType != shapeOpResTy)
6885 return failure();
6886
6887 auto numDimsToDrop =
6888 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6889
6890 // No unit dims to drop
6891 if (!numDimsToDrop)
6892 return failure();
6893
6894 if (createMaskOp) {
6895 auto maskOperands = createMaskOp.getOperands();
6896 auto numMaskOperands = maskOperands.size();
6897
6898 // Check every mask dim size to see whether it can be dropped
6899 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6900 --i) {
6901 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6902 if (!constant || (constant.value() != 1))
6903 return failure();
6904 }
6905 SmallVector<Value> newMaskOperands =
6906 maskOperands.drop_back(numDimsToDrop);
6907
6908 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6909 newMaskOperands);
6910 return success();
6911 }
6912
6913 if (constantMaskOp) {
6914 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6915 auto numMaskOperands = maskDimSizes.size();
6916
6917 // Check every mask dim size to see whether it can be dropped
6918 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6919 --i) {
6920 if (maskDimSizes[i] != 1)
6921 return failure();
6922 }
6923
6924 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6925 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6926 newMaskOperands);
6927 return success();
6928 }
6929
6930 return failure();
6931 }
6932};
6933
6934// vector.broadcast has two distinct semantic modes: duplication across leading
6935// dimensions, and stretching across inner dimensions. This helper returns the
6936// product of the inner-dimension stretching factors.
6937int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6938 ArrayRef<int64_t> dstShape) {
6939 int stretchingFactor = 1;
6940 int numLeadingDims = dstShape.size() - srcShape.size();
6941 for (int i = 0, e = srcShape.size(); i < e; i++) {
6942 int64_t dstDim = dstShape[numLeadingDims + i];
6943 if (srcShape[i] == 1 && dstDim != 1) {
6944 stretchingFactor *= dstDim;
6945 }
6946 }
6947 return stretchingFactor;
6948}
6949
6950/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
6951class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6952public:
6953 using Base::Base;
6954
6955 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6956 PatternRewriter &rewriter) const override {
6957 auto broadcastOp =
6958 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6959 if (!broadcastOp)
6960 return failure();
6961
6962 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6963 bool srcIsScalar = !srcVectorType;
6964
6965 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6966 // Example
6967 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6968 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6969 // to
6970 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6971 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6972 ArrayRef<int64_t> dstShape = dstVectorType.getShape();
6973 ArrayRef<int64_t> srcShape =
6974 srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
6975 ArrayRef<int64_t> broadcastShape =
6976 broadcastOp.getResultVectorType().getShape();
6977
6978 if (!srcIsScalar) {
6979 if (isBroadcastableTo(srcVectorType, dstVectorType) !=
6980 BroadcastableToResult::Success) {
6981 return failure();
6982 }
6983 // Avoid folding if this would result in switching between the two
6984 // distinct semantic modes of vector.broadcast (duplication vs
6985 // stretching). See https://github.com/llvm/llvm-project/issues/190614.
6986 // This is detected by a change in the stretching factor. However if the
6987 // source has a single element, there is no ambiguity.
6988 if (srcVectorType.getNumElements() != 1) {
6989 if (getBroadcastStretchingFactor(srcShape, dstShape) !=
6990 getBroadcastStretchingFactor(srcShape, broadcastShape)) {
6991 return failure();
6992 }
6993 }
6994 }
6995
6996 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
6997 broadcastOp.getSource());
6998 return success();
6999 }
7000};
7001
7002/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X)
7003///
7004/// BEFORE:
7005/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
7006/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
7007/// AFTER:
7008/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
7009///
7010/// Note: this transformation is implemented as an OpRewritePattern, not as a
7011/// fold, because we have to create new op FromElementsOp with updated result
7012/// type. This cannot be done with a fold, because fold cannot create new ops
7013/// and the existing FromElementsOp result type differs from the ShapeCastOp
7014/// result type. Mutating the FromElementsOp (not root op) would violate the
7015/// fold contract and break other users.
7016class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
7017public:
7018 using Base::Base;
7019
7020 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
7021 PatternRewriter &rewriter) const override {
7022 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
7023 if (!fromElements)
7024 return failure();
7025
7026 rewriter.replaceOpWithNewOp<FromElementsOp>(
7027 shapeCastOp, shapeCastOp.getResultVectorType(),
7028 fromElements.getElements());
7029 return success();
7030 }
7031};
7032
7033} // namespace
7034
7035void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
7036 MLIRContext *context) {
7037 results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
7038 FoldShapeCastOfFromElements>(context);
7039}
7040
7041//===----------------------------------------------------------------------===//
7042// VectorBitCastOp
7043//===----------------------------------------------------------------------===//
7044
7045LogicalResult BitCastOp::verify() {
7046 auto sourceVectorType = getSourceVectorType();
7047 auto resultVectorType = getResultVectorType();
7048
7049 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
7050 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
7051 return emitOpError("dimension size mismatch at: ") << i;
7052 }
7053
7054 DataLayout dataLayout = DataLayout::closest(*this);
7055 auto sourceElementBits =
7056 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
7057 auto resultElementBits =
7058 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
7059
7060 if (sourceVectorType.getRank() == 0) {
7061 if (sourceElementBits != resultElementBits)
7062 return emitOpError("source/result bitwidth of the 0-D vector element "
7063 "types must be equal");
7064 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
7065 resultElementBits * resultVectorType.getShape().back()) {
7066 return emitOpError(
7067 "source/result bitwidth of the minor 1-D vectors must be equal");
7068 }
7069
7070 return success();
7071}
7072
7073OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
7074 // Nop cast.
7075 if (getSource().getType() == getResult().getType())
7076 return getSource();
7077
7078 // Canceling bitcasts.
7079 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
7080 if (getResult().getType() == otherOp.getSource().getType())
7081 return otherOp.getSource();
7082
7083 setOperand(otherOp.getSource());
7084 return getResult();
7085 }
7086
7087 Attribute sourceConstant = adaptor.getSource();
7088 if (!sourceConstant)
7089 return {};
7090
7091 Type srcElemType = getSourceVectorType().getElementType();
7092 Type dstElemType = getResultVectorType().getElementType();
7093
7094 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
7095 if (floatPack.isSplat()) {
7096 auto splat = floatPack.getSplatValue<FloatAttr>();
7097
7098 // Casting fp16 into fp32.
7099 if (srcElemType.isF16() && dstElemType.isF32()) {
7100 uint32_t bits = static_cast<uint32_t>(
7101 splat.getValue().bitcastToAPInt().getZExtValue());
7102 // Duplicate the 16-bit pattern.
7103 bits = (bits << 16) | (bits & 0xffff);
7104 APInt intBits(32, bits);
7105 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
7106 return DenseElementsAttr::get(getResultVectorType(), floatBits);
7107 }
7108 }
7109 }
7110
7111 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
7112 if (intPack.isSplat()) {
7113 auto splat = intPack.getSplatValue<IntegerAttr>();
7114
7115 if (llvm::isa<IntegerType>(dstElemType) && srcElemType.isIntOrFloat()) {
7116 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
7117 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
7118
7119 // Casting to a larger integer bit width.
7120 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
7121 APInt intBits = splat.getValue().zext(dstBitWidth);
7122
7123 // Duplicate the lower width element.
7124 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
7125 intBits = (intBits << srcBitWidth) | intBits;
7126 return DenseElementsAttr::get(getResultVectorType(), intBits);
7127 }
7128 }
7129 }
7130 }
7131
7132 return {};
7133}
7134
7135std::optional<SmallVector<int64_t, 4>> BitCastOp::getShapeForUnroll() {
7136 return llvm::to_vector<4>(getResultVectorType().getShape());
7137}
7138
7139//===----------------------------------------------------------------------===//
7140// TypeCastOp
7141//===----------------------------------------------------------------------===//
7142
7143static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
7144 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
7145 SmallVector<int64_t, 8> res(memRefType.getShape());
7146 if (vectorType)
7147 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
7148 return res;
7149}
7150
7151/// Build the canonical memRefType with a single vector.
7152/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
7153void TypeCastOp::build(OpBuilder &builder, OperationState &result,
7154 Value source) {
7155 result.addOperands(source);
7156 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
7157 VectorType vectorType =
7158 VectorType::get(extractShape(memRefType),
7160 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
7161 memRefType.getMemorySpace()));
7162}
7163
7164LogicalResult TypeCastOp::verify() {
7165 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
7166 if (!canonicalType.getLayout().isIdentity())
7167 return emitOpError("expects operand to be a memref with identity layout");
7168 if (!getResultMemRefType().getLayout().isIdentity())
7169 return emitOpError("expects result to be a memref with identity layout");
7170 if (getResultMemRefType().getMemorySpace() !=
7171 getMemRefType().getMemorySpace())
7172 return emitOpError("expects result in same memory space");
7173
7174 auto sourceType = getMemRefType();
7175 auto resultType = getResultMemRefType();
7176 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
7178 return emitOpError(
7179 "expects result and operand with same underlying scalar type: ")
7180 << resultType;
7181 if (extractShape(sourceType) != extractShape(resultType))
7182 return emitOpError(
7183 "expects concatenated result and operand shapes to be equal: ")
7184 << resultType;
7185 return success();
7186}
7187
7188//===----------------------------------------------------------------------===//
7189// TransposeOp
7190//===----------------------------------------------------------------------===//
7191
7192void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
7193 Value vector, ArrayRef<int64_t> permutation) {
7194 VectorType vt = llvm::cast<VectorType>(vector.getType());
7195 SmallVector<int64_t, 4> transposedShape(vt.getRank());
7196 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
7197 for (unsigned i = 0; i < permutation.size(); ++i) {
7198 transposedShape[i] = vt.getShape()[permutation[i]];
7199 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
7200 }
7201
7202 result.addOperands(vector);
7203 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
7204 transposedScalableDims));
7205 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
7206 builder.getDenseI64ArrayAttr(permutation));
7207}
7208
7209OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
7210 // Eliminate splat constant transpose ops.
7211 if (auto splat =
7212 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
7213 return splat.reshape(getResultVectorType());
7214
7215 // Eliminate poison transpose ops.
7216 if (matchPattern(adaptor.getVector(), ub::m_Poison()))
7217 return ub::PoisonAttr::get(getContext());
7218
7219 // Eliminate identity transposes, and more generally any transposes that
7220 // preserves the shape without permuting elements.
7221 //
7222 // Examples of what to fold:
7223 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
7224 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
7225 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
7226 //
7227 // Example of what NOT to fold:
7228 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
7229 //
7230 if (getSourceVectorType() == getResultVectorType() &&
7231 isOrderPreserving(*this))
7232 return getVector();
7233
7234 return {};
7235}
7236
7237LogicalResult vector::TransposeOp::verify() {
7238 VectorType vectorType = getSourceVectorType();
7239 VectorType resultType = getResultVectorType();
7240 int64_t rank = resultType.getRank();
7241 if (vectorType.getRank() != rank)
7242 return emitOpError("vector result rank mismatch: ") << rank;
7243 // Verify transposition array.
7244 ArrayRef<int64_t> perm = getPermutation();
7245 int64_t size = perm.size();
7246 if (rank != size)
7247 return emitOpError("transposition length mismatch: ") << size;
7248 SmallVector<bool, 8> seen(rank, false);
7249 for (const auto &ta : llvm::enumerate(perm)) {
7250 if (ta.value() < 0 || ta.value() >= rank)
7251 return emitOpError("transposition index out of range: ") << ta.value();
7252 if (seen[ta.value()])
7253 return emitOpError("duplicate position index: ") << ta.value();
7254 seen[ta.value()] = true;
7255 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7256 return emitOpError("dimension size mismatch at: ") << ta.value();
7257 }
7258 return success();
7259}
7260
7261std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7262 return llvm::to_vector<4>(getResultVectorType().getShape());
7263}
7264
7265void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7266 SetIntRangeFn setResultRanges) {
7267 setResultRanges(getResult(), argRanges.front());
7268}
7269
7270namespace {
7271
7272// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
7273class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
7274public:
7275 using Base::Base;
7276
7277 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7278 PatternRewriter &rewriter) const override {
7279 // Composes two permutations: result[i] = permutation1[permutation2[i]].
7280 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7281 ArrayRef<int64_t> permutation2) {
7282 SmallVector<int64_t, 4> result;
7283 for (auto index : permutation2)
7284 result.push_back(permutation1[index]);
7285 return result;
7286 };
7287
7288 // Return if the input of 'transposeOp' is not defined by another transpose.
7289 vector::TransposeOp parentTransposeOp =
7290 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7291 if (!parentTransposeOp)
7292 return failure();
7293
7294 SmallVector<int64_t, 4> permutation = composePermutations(
7295 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7296 // Replace 'transposeOp' with a new transpose operation.
7297 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
7298 transposeOp, transposeOp.getResult().getType(),
7299 parentTransposeOp.getVector(), permutation);
7300 return success();
7301 }
7302};
7303
7304/// Replace transpose(splat-like(v)) with broadcast(v)
7305class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
7306public:
7307 using Base::Base;
7308
7309 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7310 PatternRewriter &rewriter) const override {
7311 Value splat = getScalarSplatSource(transposeOp.getVector());
7312 if (!splat)
7313 return failure();
7314
7315 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
7316 transposeOp, transposeOp.getResultVectorType(), splat);
7317 return success();
7318 }
7319};
7320
7321/// Folds transpose(create_mask) into a new transposed create_mask.
7322class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
7323public:
7324 using Base::Base;
7325
7326 LogicalResult matchAndRewrite(TransposeOp transpOp,
7327 PatternRewriter &rewriter) const override {
7328 Value transposeSrc = transpOp.getVector();
7329 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
7330 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
7331 if (!createMaskOp && !constantMaskOp)
7332 return failure();
7333
7334 // Get the transpose permutation and apply it to the vector.create_mask or
7335 // vector.constant_mask operands.
7336 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7337
7338 if (createMaskOp) {
7339 auto maskOperands = createMaskOp.getOperands();
7340 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7341 applyPermutationToVector(newOperands, permutation);
7342
7343 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
7344 transpOp, transpOp.getResultVectorType(), newOperands);
7345 return success();
7346 }
7347
7348 // ConstantMaskOp case.
7349 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7350 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
7351
7352 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
7353 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7354 return success();
7355 }
7356};
7357
7358/// Folds transpose(shape_cast) into a new shape_cast.
7359class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
7360public:
7361 using Base::Base;
7362
7363 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7364 PatternRewriter &rewriter) const override {
7365 auto shapeCastOp =
7366 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7367 if (!shapeCastOp)
7368 return failure();
7369 if (!isOrderPreserving(transposeOp))
7370 return failure();
7371
7372 VectorType resultType = transposeOp.getType();
7373
7374 // We don't need to check isValidShapeCast at this point, because it is
7375 // guaranteed that merging the transpose into the the shape_cast is a valid
7376 // shape_cast, because the transpose just inserts/removes ones.
7377
7378 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
7379 shapeCastOp.getSource());
7380 return success();
7381 }
7382};
7383
7384/// Folds transpose(from_elements(...)) into a new from_elements with permuted
7385/// operands matching the transposed shape.
7386///
7387/// Example:
7388///
7389/// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
7390/// vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to
7391/// vector<3x2xi32>
7392///
7393/// becomes ->
7394///
7395/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
7396/// vector<3x2xi32>
7397///
7398class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
7399public:
7400 using Base::Base;
7401 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7402 PatternRewriter &rewriter) const override {
7403 auto fromElementsOp =
7404 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7405 if (!fromElementsOp)
7406 return failure();
7407
7408 VectorType srcTy = fromElementsOp.getDest().getType();
7409 VectorType dstTy = transposeOp.getType();
7410
7411 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7412 int64_t rank = srcTy.getRank();
7413
7414 // Build inverse permutation to map destination indices back to source.
7415 SmallVector<int64_t> inversePerm(rank, 0);
7416 for (int64_t i = 0; i < rank; ++i)
7417 inversePerm[permutation[i]] = i;
7418
7419 ArrayRef<int64_t> srcShape = srcTy.getShape();
7420 ArrayRef<int64_t> dstShape = dstTy.getShape();
7421 SmallVector<int64_t> srcIdx(rank, 0);
7422 SmallVector<int64_t> dstIdx(rank, 0);
7423 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
7424 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
7425
7426 auto elementsOld = fromElementsOp.getElements();
7427 SmallVector<Value> elementsNew;
7428 int64_t dstNumElements = dstTy.getNumElements();
7429 elementsNew.reserve(dstNumElements);
7430
7431 // For each element in destination row-major order, pick the corresponding
7432 // source element.
7433 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7434 // Pick the destination element index.
7435 dstIdx = delinearize(linearIdx, dstStrides);
7436 // Map the destination element index to the source element index.
7437 for (int64_t j = 0; j < rank; ++j)
7438 srcIdx[j] = dstIdx[inversePerm[j]];
7439 // Linearize the source element index.
7440 int64_t srcLin = linearize(srcIdx, srcStrides);
7441 // Add the source element to the new elements.
7442 elementsNew.push_back(elementsOld[srcLin]);
7443 }
7444
7445 rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
7446 elementsNew);
7447 return success();
7448 }
7449};
7450
7451/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
7452/// 'order preserving', where 'order preserving' means the flattened
7453/// inputs and outputs of the transpose have identical (numerical) values.
7454///
7455/// Example:
7456/// ```
7457/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
7458/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
7459/// to vector<8x1xi32>
7460/// ```
7461/// can be rewritten as the equivalent
7462/// ```
7463/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
7464/// ```
7465/// The algorithm works by partitioning dimensions into groups that can be
7466/// locally permuted while preserving order, and checks that the transpose
7467/// only permutes within these groups.
7468///
7469/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
7470/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
7471/// broadcasting from 1x1x4x1x1x7.
7472/// ^^^ ^ ^^^ ^
7473/// groups: 0 1 2 3
7474/// Order preserving permutations for this example are ones that only permute
7475/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
7476class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
7477public:
7478 using Base::Base;
7479 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7480 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7481
7482 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7483 PatternRewriter &rewriter) const override {
7484
7485 vector::BroadcastOp broadcast =
7486 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
7487 if (!broadcast) {
7488 return rewriter.notifyMatchFailure(transpose,
7489 "not preceded by a broadcast");
7490 }
7491
7492 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
7493 VectorType outputType = transpose.getResultVectorType();
7494
7495 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
7496 bool inputIsScalar = !inputType;
7497 if (inputIsScalar) {
7498 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7499 broadcast.getSource());
7500 return success();
7501 }
7502
7503 ArrayRef<int64_t> permutation = transpose.getPermutation();
7504 ArrayRef<int64_t> inputShape = inputType.getShape();
7505 int64_t inputRank = inputType.getRank();
7506 int64_t outputRank = transpose.getType().getRank();
7507 int64_t deltaRank = outputRank - inputRank;
7508
7509 int low = 0;
7510 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7511 bool notOne = inputShape[inputIndex] != 1;
7512 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7513 bool groupEndFound = notOne || prevNotOne;
7514 if (groupEndFound) {
7515 int high = inputIndex + deltaRank;
7516 // Return failure if not all permutation destinations for indices in
7517 // [low, high) are in [low, high), i.e. the permutation is not local to
7518 // the group.
7519 for (int i = low; i < high; ++i) {
7520 if (permutation[i] < low || permutation[i] >= high) {
7521 return rewriter.notifyMatchFailure(
7522 transpose, "permutation not local to group");
7523 }
7524 }
7525 low = high;
7526 }
7527 }
7528
7529 // We don't need to check the final group [low, outputRank) because if it is
7530 // not locally bound, there must be a preceding group that already failed
7531 // the check (impossible to have just 1 non-locally bound group).
7532
7533 // The preceding logic also ensures that at this point, the output of the
7534 // transpose is definitely broadcastable from the input shape, assert so:
7535 assert(vector::isBroadcastableTo(inputType, outputType) ==
7536 vector::BroadcastableToResult::Success &&
7537 "not broadcastable directly to transpose output");
7538
7539 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7540 broadcast.getSource());
7541
7542 return success();
7543 }
7544};
7545
7546} // namespace
7547
/// Adds the canonicalization patterns of `vector.transpose` listed below to
/// `results`, all constructed with `context`.
void vector::TransposeOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
              FoldTransposeSplat, FoldTransposeFromElements,
              FoldTransposeBroadcast>(context);
}
7554
7555//===----------------------------------------------------------------------===//
7556// ConstantMaskOp
7557//===----------------------------------------------------------------------===//
7558
7559void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7560 VectorType type, ConstantMaskKind kind) {
7561 assert(kind == ConstantMaskKind::AllTrue ||
7562 kind == ConstantMaskKind::AllFalse);
7563 build(builder, result, type,
7564 kind == ConstantMaskKind::AllTrue
7565 ? type.getShape()
7566 : SmallVector<int64_t>(type.getRank(), 0));
7567}
7568
7569LogicalResult ConstantMaskOp::verify() {
7570 auto resultType = llvm::cast<VectorType>(getResult().getType());
7571 // Check the corner case of 0-D vectors first.
7572 if (resultType.getRank() == 0) {
7573 if (getMaskDimSizes().size() != 1)
7574 return emitError("array attr must have length 1 for 0-D vectors");
7575 auto dim = getMaskDimSizes()[0];
7576 if (dim != 0 && dim != 1)
7577 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7578 return success();
7579 }
7580
7581 // Verify that array attr size matches the rank of the vector result.
7582 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7583 return emitOpError(
7584 "must specify array attr of size equal vector result rank");
7585 // Verify that each array attr element is in bounds of corresponding vector
7586 // result dimension size.
7587 auto resultShape = resultType.getShape();
7588 auto resultScalableDims = resultType.getScalableDims();
7589 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7590 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7591 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7592 return emitOpError(
7593 "array attr of size out of bounds of vector result dimension size");
7594 if (resultScalableDims[index] && maskDimSize != 0 &&
7595 maskDimSize != resultShape[index])
7596 return emitOpError(
7597 "only supports 'none set' or 'all set' scalable dimensions");
7598 }
7599 // Verify that if one mask dim size is zero, they all should be zero (because
7600 // the mask region is a conjunction of each mask dimension interval).
7601 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7602 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7603 if (anyZeros && !allZeros)
7604 return emitOpError("expected all mask dim sizes to be zeros, "
7605 "as a result of conjunction with zero mask dim");
7606 return success();
7607}
7608
7609bool ConstantMaskOp::isAllOnesMask() {
7610 auto resultType = getVectorType();
7611 // Check the corner case of 0-D vectors first.
7612 if (resultType.getRank() == 0) {
7613 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7614 return getMaskDimSizes()[0] == 1;
7615 }
7616 for (const auto [resultSize, maskDimSize] :
7617 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7618 if (maskDimSize < resultSize)
7619 return false;
7620 }
7621 return true;
7622}
7623
7624OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7625 ArrayRef<int64_t> bounds = getMaskDimSizes();
7626 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7627
7628 auto createBoolSplat = [&](bool x) {
7629 return SplatElementsAttr::get(getVectorType(),
7631 };
7632
7633 // Check the corner case of 0-D vectors first.
7634 if (vectorSizes.empty()) {
7635 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7636 return createBoolSplat(bounds[0] == 1);
7637 }
7638 // Fold vector.constant_mask to splat if possible.
7639 if (bounds == vectorSizes)
7640 return createBoolSplat(true);
7641 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7642 return createBoolSplat(false);
7643 return OpFoldResult();
7644}
7645
7646//===----------------------------------------------------------------------===//
7647// CreateMaskOp
7648//===----------------------------------------------------------------------===//
7649
7650void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7651 VectorType type,
7652 ArrayRef<OpFoldResult> mixedOperands) {
7653 SmallVector<Value> operands =
7654 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7655 build(builder, result, type, operands);
7656}
7657
7658LogicalResult CreateMaskOp::verify() {
7659 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7660 // Verify that an operand was specified for each result vector each dimension.
7661 if (vectorType.getRank() == 0) {
7662 if (getNumOperands() != 1)
7663 return emitOpError(
7664 "must specify exactly one operand for 0-D create_mask");
7665 } else if (getNumOperands() !=
7666 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7667 return emitOpError(
7668 "must specify an operand for each result vector dimension");
7669 }
7670 return success();
7671}
7672
7673namespace {
7674
7675/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7676///
7677/// Ex 1:
7678/// %c2 = arith.constant 2 : index
7679/// %c3 = arith.constant 3 : index
7680/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7681/// Becomes:
7682/// vector.constant_mask [3, 2] : vector<4x3xi1>
7683///
7684/// Ex 2:
7685/// %c_neg_1 = arith.constant -1 : index
7686/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7687/// becomes:
7688/// vector.constant_mask [0] : vector<[8]xi1>
7689///
7690/// Ex 3:
7691/// %c8 = arith.constant 8 : index
7692/// %c16 = arith.constant 16 : index
7693/// %0 = vector.vscale
7694/// %1 = arith.muli %0, %c16 : index
7695/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7696/// becomes:
7697/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7698class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7699public:
7700 using Base::Base;
7701
7702 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7703 PatternRewriter &rewriter) const override {
7704 VectorType maskType = createMaskOp.getVectorType();
7705 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7706 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7707
7708 // Special case: Rank zero shape.
7709 constexpr std::array<int64_t, 1> rankZeroShape{1};
7710 constexpr std::array<bool, 1> rankZeroScalableDims{false};
7711 if (maskType.getRank() == 0) {
7712 maskTypeDimSizes = rankZeroShape;
7713 maskTypeDimScalableFlags = rankZeroScalableDims;
7714 }
7715
7716 // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7717 // collect the `constantDims` (for the ConstantMaskOp).
7718 SmallVector<int64_t, 4> constantDims;
7719 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7720 if (auto intSize = getConstantIntValue(dimSize)) {
7721 // Constant value.
7722 // If the mask dim is non-scalable this can be any value.
7723 // If the mask dim is scalable only zero (all-false) is supported.
7724 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7725 return failure();
7726 constantDims.push_back(*intSize);
7727 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7728 // Constant vscale multiple (e.g. 4 x vscale).
7729 // Must be all-true to fold to a ConstantMask.
7730 if (vscaleMultiplier < maskTypeDimSizes[i])
7731 return failure();
7732 constantDims.push_back(*vscaleMultiplier);
7733 } else {
7734 return failure();
7735 }
7736 }
7737
7738 // Clamp values to constant_mask bounds.
7739 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7740 value = std::clamp<int64_t>(value, 0, maskDimSize);
7741
7742 // If one of dim sizes is zero, set all dims to zero.
7743 if (llvm::is_contained(constantDims, 0))
7744 constantDims.assign(constantDims.size(), 0);
7745
7746 // Replace 'createMaskOp' with ConstantMaskOp.
7747 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7748 constantDims);
7749 return success();
7750 }
7751};
7752
7753} // namespace
7754
/// Adds `CreateMaskFolder` to `results` as the canonicalization pattern of
/// `vector.create_mask`.
void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<CreateMaskFolder>(context);
}
7759
7760//===----------------------------------------------------------------------===//
7761// MaskOp
7762//===----------------------------------------------------------------------===//
7763
7764void MaskOp::build(
7765 OpBuilder &builder, OperationState &result, Value mask,
7766 Operation *maskableOp,
7767 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7768 assert(maskRegionBuilder &&
7769 "builder callback for 'maskRegion' must be present");
7770
7771 result.addOperands(mask);
7772 OpBuilder::InsertionGuard guard(builder);
7773 Region *maskRegion = result.addRegion();
7774 builder.createBlock(maskRegion);
7775 maskRegionBuilder(builder, maskableOp);
7776}
7777
/// Builds a `vector.mask` with result types `resultTypes` and no passthru
/// value; forwards to the passthru overload with a null `Value`.
void MaskOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
    Value mask, Operation *maskableOp,
    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
        maskRegionBuilder);
}
7785
7786void MaskOp::build(
7787 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7788 Value mask, Value passthru, Operation *maskableOp,
7789 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7790 build(builder, result, mask, maskableOp, maskRegionBuilder);
7791 if (passthru)
7792 result.addOperands(passthru);
7793 result.addTypes(resultTypes);
7794}
7795
/// Parses the custom assembly form of `vector.mask`, in this order:
///
///   mask-operand (`,` passthru-operand)? region attr-dict?
///       `:` mask-type (`->` result-types)?
///
/// A passthru operand requires at least one result type; its type is taken
/// from the first result type.
ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the op region.
  result.regions.reserve(1);
  Region &maskRegion = *result.addRegion();

  auto &builder = parser.getBuilder();

  // Parse all the operands.
  OpAsmParser::UnresolvedOperand mask;
  if (parser.parseOperand(mask))
    return failure();

  // Optional passthru operand. It is only present when a comma follows the
  // mask; `parsePassthru` remembers that for the operand resolution below.
  OpAsmParser::UnresolvedOperand passthru;
  ParseResult parsePassthru = parser.parseOptionalComma();
  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
    return failure();

  // Parse op region.
  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  // Add the `vector.yield` terminator if the textual form omitted it.
  MaskOp::ensureTerminator(maskRegion, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse all the types.
  Type maskType;
  if (parser.parseColonType(maskType))
    return failure();

  SmallVector<Type> resultTypes;
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  result.types.append(resultTypes);

  // Resolve operands. The mask must be resolved first so that it precedes
  // the passthru in the operand list.
  if (parser.resolveOperand(mask, maskType, result.operands))
    return failure();

  if (parsePassthru.succeeded()) {
    if (resultTypes.empty())
      return parser.emitError(
          parser.getNameLoc(),
          "expects a result if passthru operand is provided");

    // The passthru has the type of the (first) result.
    if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
      return failure();
  }

  return success();
}
7850
7851void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7852 p << " " << getMask();
7853 if (getPassthru())
7854 p << ", " << getPassthru();
7855
7856 // Print single masked operation and skip terminator.
7857 p << " { ";
7858 Block *singleBlock = &getMaskRegion().getBlocks().front();
7859 if (singleBlock && !singleBlock->getOperations().empty())
7860 p.printCustomOrGenericOp(&singleBlock->front());
7861 p << " }";
7862
7863 p.printOptionalAttrDict(getOperation()->getAttrs());
7864
7865 p << " : " << getMask().getType();
7866 if (getNumResults() > 0)
7867 p << " -> " << getResultTypes();
7868}
7869
7870void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7871 // 1. For an empty `vector.mask`, create a default terminator.
7872 if (region.empty() || region.front().empty()) {
7873 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7874 MaskOp>::ensureTerminator(region, builder, loc);
7875 return;
7876 }
7877
7878 // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7879 Block &block = region.front();
7880 if (isa<vector::YieldOp>(block.back()))
7881 return;
7882
7883 // 3. For a non-empty `vector.mask` without an explicit terminator:
7884
7885 // Create default terminator if the number of masked operations is not
7886 // one. This case will trigger a verification failure.
7887 if (block.getOperations().size() != 1) {
7888 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7889 MaskOp>::ensureTerminator(region, builder, loc);
7890 return;
7891 }
7892
7893 // Create a terminator that yields the results from the masked operation.
7894 OpBuilder opBuilder(builder.getContext());
7895 Operation *maskedOp = &block.front();
7896 opBuilder.setInsertionPointToEnd(&block);
7897 vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7898}
7899
/// Verifies the structure of `vector.mask`: the region holds a `vector.yield`
/// terminator preceded by at most one operation; that operation, if present,
/// implements `MaskableOpInterface`, its results are exactly what is yielded
/// and returned, the mask has the type the maskable operation expects, and an
/// optional passthru is supported by the operation and matches its single
/// result type.
LogicalResult MaskOp::verify() {
  // Structural checks.
  Block &block = getMaskRegion().getBlocks().front();
  if (block.getOperations().empty())
    return emitOpError("expects a terminator within the mask region");

  // At most two operations: the masked operation and the terminator.
  unsigned numMaskRegionOps = block.getOperations().size();
  if (numMaskRegionOps > 2)
    return emitOpError("expects only one operation to mask");

  // Terminator checks.
  auto terminator = dyn_cast<vector::YieldOp>(block.back());
  if (!terminator)
    return emitOpError("expects a terminator within the mask region");

  if (terminator->getNumOperands() != getNumResults())
    return emitOpError(
        "expects number of results to match mask region yielded values");

  // Empty vector.mask. Nothing else to check.
  if (numMaskRegionOps == 1)
    return success();

  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
  if (!maskableOp)
    return emitOpError("expects a MaskableOpInterface within the mask region");

  // Result checks.
  if (maskableOp->getNumResults() != getNumResults())
    return emitOpError("expects number of results to match maskable operation "
                       "number of results");

  // The terminator must yield the maskable operation's results, in order.
  if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
    return emitOpError("expects all the results from the MaskableOpInterface "
                       "to match all the values returned by the terminator");

  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
    return emitOpError(
        "expects result type to match maskable operation result type");

  if (llvm::count_if(maskableOp->getResultTypes(),
                     [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
    return emitOpError("multiple vector results not supported");

  // Mask checks.
  Type expectedMaskType = maskableOp.getExpectedMaskType();
  if (getMask().getType() != expectedMaskType)
    return emitOpError("expects a ")
           << expectedMaskType << " mask for the maskable operation";

  // Passthru checks.
  Value passthru = getPassthru();
  if (passthru) {
    if (!maskableOp.supportsPassthru())
      return emitOpError(
          "doesn't expect a passthru argument for this maskable operation");

    if (maskableOp->getNumResults() != 1)
      return emitOpError("expects result when passthru argument is provided");

    if (passthru.getType() != maskableOp->getResultTypes()[0])
      return emitOpError("expects passthru type to match result type");
  }

  return success();
}
7966
7967/// Folds empty `vector.mask` with no passthru operand and with or without
7968/// return values. For example:
7969///
7970/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7971/// vector<8xi1> -> vector<8xf32>
7972/// %1 = user_op %0 : vector<8xf32>
7973///
7974/// becomes:
7975///
7976/// %0 = user_op %a : vector<8xf32>
7977///
7978/// Empty `vector.mask` with passthru operand are handled by the canonicalizer
7979/// as it requires creating new operations.
7980
7981static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7982 SmallVectorImpl<OpFoldResult> &results) {
7983 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7984 return failure();
7985
7986 Block *block = maskOp.getMaskBlock();
7987 auto terminator = cast<vector::YieldOp>(block->front());
7988 if (terminator.getNumOperands() == 0)
7989 return failure();
7990
7991 // `vector.mask` has results, propagate the results.
7992 llvm::append_range(results, terminator.getOperands());
7993 return success();
7994}
7995
/// Folds `vector.mask` in two situations:
///   1. an empty mask without passthru that yields values (see
///      `foldEmptyMaskOp`): the yielded values replace the results;
///   2. a mask classified as `MaskFormat::AllTrue` around a maskable
///      operation: the operation is moved right before the `vector.mask` and
///      its results replace the results of the `vector.mask`.
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
    return success();

  MaskFormat maskFormat = getMaskFormat(getMask());
  if (maskFormat != MaskFormat::AllTrue)
    return failure();

  // Move maskable operation outside of the `vector.mask` region.
  // If there is no maskable op (empty body), the fold cannot proceed; the
  // canonicalizer handles this case instead.
  Operation *maskableOp = getMaskableOp();
  if (!maskableOp)
    return failure();
  // Detach the operation from its users (NOTE(review): presumably just the
  // region terminator - confirm) before hoisting it out of the region.
  maskableOp->dropAllUses();
  maskableOp->moveBefore(getOperation());

  llvm::append_range(results, maskableOp->getResults());
  return success();
}
8017
8018/// Canonialize empty `vector.mask` operations that can't be handled in
8019/// `VectorMask::fold` as they require creating new operations.
8020///
8021/// Example 1: Empty `vector.mask` with passthru operand.
8022///
8023/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
8024/// vector<8xi1> -> vector<8xf32>
8025///
8026/// becomes:
8027///
8028/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
8029///
8030class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
8031 using Base::Base;
8032
8033 LogicalResult matchAndRewrite(MaskOp maskOp,
8034 PatternRewriter &rewriter) const override {
8035 if (!maskOp.isEmpty())
8036 return failure();
8037
8038 if (!maskOp.hasPassthru())
8039 return failure();
8040
8041 // arith.select with a vector condition requires the value types to be
8042 // vectors of the same shape. Since vector.mask always has a vector mask
8043 // type, bail out when any result type doesn't match the mask shape to
8044 // avoid creating invalid IR.
8045 VectorType maskType = maskOp.getMask().getType();
8046 for (Type resultType : maskOp.getResultTypes()) {
8047 auto vecResultType = dyn_cast<VectorType>(resultType);
8048 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
8049 return failure();
8050 }
8051
8052 Block *block = maskOp.getMaskBlock();
8053 auto terminator = cast<vector::YieldOp>(block->front());
8054 assert(terminator.getNumOperands() == 1 &&
8055 "expected one result when passthru is provided");
8056
8057 rewriter.replaceOpWithNewOp<arith::SelectOp>(
8058 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
8059 terminator.getOperand(0), maskOp.getPassthru());
8060
8061 return success();
8062 }
8063};
8064
/// Adds `CanonializeEmptyMaskOp` to `results` as the canonicalization pattern
/// of `vector.mask`.
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<CanonializeEmptyMaskOp>(context);
}
8069
8070// MaskingOpInterface definitions.
8071
8072/// Returns the operation masked by this 'vector.mask'.
8073Operation *MaskOp::getMaskableOp() {
8074 Block *block = getMaskBlock();
8075 if (block->getOperations().size() < 2)
8076 return nullptr;
8077
8078 return &block->front();
8079}
8080
8081/// Returns true if 'vector.mask' has a passthru value.
8082bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
8083
8084//===----------------------------------------------------------------------===//
8085// ScanOp
8086//===----------------------------------------------------------------------===//
8087
8088LogicalResult ScanOp::verify() {
8089 VectorType srcType = getSourceType();
8090 VectorType initialType = getInitialValueType();
8091 // Check reduction dimension < rank.
8092 int64_t srcRank = srcType.getRank();
8093 int64_t reductionDim = getReductionDim();
8094 if (reductionDim >= srcRank)
8095 return emitOpError("reduction dimension ")
8096 << reductionDim << " has to be less than " << srcRank;
8097
8098 // Check that rank(initial_value) = rank(src) - 1.
8099 int64_t initialValueRank = initialType.getRank();
8100 if (initialValueRank != srcRank - 1)
8101 return emitOpError("initial value rank ")
8102 << initialValueRank << " has to be equal to " << srcRank - 1;
8103
8104 // Check shapes of initial value and src.
8105 ArrayRef<int64_t> srcShape = srcType.getShape();
8106 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
8107 SmallVector<int64_t> expectedShape;
8108 for (int i = 0; i < srcRank; i++) {
8109 if (i != reductionDim)
8110 expectedShape.push_back(srcShape[i]);
8111 }
8112 if (!llvm::equal(initialValueShapes, expectedShape)) {
8113 return emitOpError("incompatible input/initial value shapes");
8114 }
8115
8116 // Verify supported reduction kind.
8117 Type eltType = getDestType().getElementType();
8118 if (!isSupportedCombiningKind(getKind(), eltType))
8119 return emitOpError("unsupported reduction type ")
8120 << eltType << " for kind '" << stringifyCombiningKind(getKind())
8121 << "'";
8122
8123 return success();
8124}
8125
8127 RewritePatternSet &patterns, PatternBenefit benefit) {
8128 patterns
8129 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
8130 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
8131 StridedSliceConstantMaskFolder, TransposeFolder>(
8132 patterns.getContext(), benefit);
8133}
8134
8135Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
8136 CombiningKind kind, Value v1, Value acc,
8137 arith::FastMathFlagsAttr fastmath,
8138 Value mask) {
8139 Type t1 = getElementTypeOrSelf(v1.getType());
8140 Type tAcc = getElementTypeOrSelf(acc.getType());
8141 Value result;
8142
8143 switch (kind) {
8144 case CombiningKind::ADD:
8145 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
8146 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
8147 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8148 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
8149 else
8150 llvm_unreachable("invalid value types for ADD reduction");
8151 break;
8152 case CombiningKind::AND:
8153 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8154 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
8155 break;
8156 case CombiningKind::MAXNUMF:
8157 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8158 "expected float values");
8159 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
8160 break;
8161 case CombiningKind::MAXIMUMF:
8162 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8163 "expected float values");
8164 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
8165 break;
8166 case CombiningKind::MINNUMF:
8167 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8168 "expected float values");
8169 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
8170 break;
8171 case CombiningKind::MINIMUMF:
8172 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8173 "expected float values");
8174 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
8175 break;
8176 case CombiningKind::MAXSI:
8177 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8178 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
8179 break;
8180 case CombiningKind::MINSI:
8181 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8182 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
8183 break;
8184 case CombiningKind::MAXUI:
8185 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8186 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
8187 break;
8188 case CombiningKind::MINUI:
8189 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8190 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
8191 break;
8192 case CombiningKind::MUL:
8193 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
8194 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
8195 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8196 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
8197 else
8198 llvm_unreachable("invalid value types for MUL reduction");
8199 break;
8200 case CombiningKind::OR:
8201 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8202 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
8203 break;
8204 case CombiningKind::XOR:
8205 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8206 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
8207 break;
8208 };
8209
8210 assert(result && "unknown CombiningKind");
8211 return selectPassthru(b, mask, result, acc);
8212}
8213
8214//===----------------------------------------------------------------------===//
8215// StepOp
8216//===----------------------------------------------------------------------===//
8217
8218void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
8219 SetIntRangeFn setResultRanges) {
8220 auto resultType = cast<VectorType>(getType());
8221 if (resultType.isScalable()) {
8222 return;
8223 }
8224 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
8225 APInt zero(bitwidth, 0);
8226 APInt high(bitwidth, resultType.getDimSize(0) - 1);
8227 ConstantIntRanges result = {zero, high, zero, high};
8228 setResultRanges(getResult(), result);
8229}
8230
8231namespace {
8232
8233/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
8234/// constant large enough such that the result is the same at all indices.
8235///
8236/// For example, rewrite the 'greater than' comparison below,
8237///
8238/// ```mlir
8239/// %cst = arith.constant dense<7> : vector<3xindex>
8240/// %stp = vector.step : vector<3xindex>
8241/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
8242/// ```
8243///
8244/// as,
8245///
8246/// ```mlir
8247/// %out = arith.constant dense<false> : vector<3xi1>.
8248/// ```
8249///
8250/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
8251/// is false at ALL indices we fold. If the constant was 1, then
8252/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
8253/// conservatively preferring the 'compact' vector.step representation.
8254///
8255/// Note: this folder only works for the case where the constant (`%cst` above)
8256/// is the second operand of the comparison. The arith.cmpi canonicalizer will
8257/// ensure that constants are always second (on the right).
8258struct StepCompareFolder : public OpRewritePattern<StepOp> {
8259 using Base::Base;
8260
8261 LogicalResult matchAndRewrite(StepOp stepOp,
8262 PatternRewriter &rewriter) const override {
8263 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
8264
8265 for (OpOperand &use : stepOp.getResult().getUses()) {
8266 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8267 if (!cmpiOp)
8268 continue;
8269
8270 // arith.cmpi canonicalizer makes constants final operands.
8271 const unsigned stepOperandNumber = use.getOperandNumber();
8272 if (stepOperandNumber != 0)
8273 continue;
8274
8275 // Check that operand 1 is a constant.
8276 unsigned constOperandNumber = 1;
8277 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8278 std::optional<int64_t> maybeConstValue =
8279 getConstantIntValue(otherOperand);
8280 if (!maybeConstValue.has_value())
8281 continue;
8282
8283 int64_t constValue = maybeConstValue.value();
8284 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8285
8286 auto maybeSplat = [&]() -> std::optional<bool> {
8287 // Handle ult (unsigned less than) and uge (unsigned greater equal).
8288 if ((pred == arith::CmpIPredicate::ult ||
8289 pred == arith::CmpIPredicate::uge) &&
8290 stepSize <= constValue)
8291 return pred == arith::CmpIPredicate::ult;
8292
8293 // Handle ule and ugt.
8294 if ((pred == arith::CmpIPredicate::ule ||
8295 pred == arith::CmpIPredicate::ugt) &&
8296 stepSize - 1 <= constValue) {
8297 return pred == arith::CmpIPredicate::ule;
8298 }
8299
8300 // Handle eq and ne.
8301 if ((pred == arith::CmpIPredicate::eq ||
8302 pred == arith::CmpIPredicate::ne) &&
8303 stepSize <= constValue)
8304 return pred == arith::CmpIPredicate::ne;
8305
8306 return std::nullopt;
8307 }();
8308
8309 if (!maybeSplat.has_value())
8310 continue;
8311
8312 rewriter.setInsertionPointAfter(cmpiOp);
8313
8314 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8315 if (!type)
8316 continue;
8317
8318 auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
8319 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8320 type, boolAttr);
8321
8322 rewriter.replaceOp(cmpiOp, splat);
8323 return success();
8324 }
8325
8326 return failure();
8327 }
8328};
8329} // namespace
8330
8331void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8332 MLIRContext *context) {
8333 results.add<StepCompareFolder>(context);
8334}
8335
8336//===----------------------------------------------------------------------===//
8337// Vector Masking Utilities
8338//===----------------------------------------------------------------------===//
8339
8340/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
8341/// as masked operation.
8342void mlir::vector::createMaskOpRegion(OpBuilder &builder,
8343 Operation *maskableOp) {
8344 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
8345 Block *insBlock = builder.getInsertionBlock();
8346 // Create a block and move the op to that block.
8347 insBlock->getOperations().splice(
8348 insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
8349 YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
8350}
8351
8352/// Creates a vector.mask operation around a maskable operation. Returns the
8353/// vector.mask operation if the mask provided is valid. Otherwise, returns
8354/// the maskable operation itself.
8355Operation *mlir::vector::maskOperation(OpBuilder &builder,
8356 Operation *maskableOp, Value mask,
8357 Value passthru) {
8358 if (!mask)
8359 return maskableOp;
8360 if (passthru)
8361 return MaskOp::create(builder, maskableOp->getLoc(),
8362 maskableOp->getResultTypes(), mask, passthru,
8363 maskableOp, createMaskOpRegion);
8364 return MaskOp::create(builder, maskableOp->getLoc(),
8365 maskableOp->getResultTypes(), mask, maskableOp,
8367}
8368
8369/// Creates a vector select operation that picks values from `newValue` or
8370/// `passthru` for each result vector lane based on `mask`. This utility is used
8371/// to propagate the pass-thru value of vector.mask or for cases where only the
8372/// pass-thru value propagation is needed. VP intrinsics do not support
8373/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
8374/// usually able to match op + select patterns and fold them into a native
8375/// target instructions.
8376Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
8377 Value newValue, Value passthru) {
8378 if (!mask)
8379 return newValue;
8380
8381 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
8382 mask, newValue, passthru);
8383}
8384
8385//===----------------------------------------------------------------------===//
8386// InterleaveOp
8387//===----------------------------------------------------------------------===//
8388
8389namespace {
8390
8391/// This folder works on the following round-trip identity:
8392/// interleave(deinterleave(x).even, deinterleave(x).odd) -> x
8393struct InterleaveDeinterleaveFolder : public OpRewritePattern<InterleaveOp> {
8394 using Base::Base;
8395
8396 LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
8397 PatternRewriter &rewriter) const override {
8398 auto lhsDefOp = interleaveOp.getLhs().getDefiningOp<DeinterleaveOp>();
8399 auto rhsDefOp = interleaveOp.getRhs().getDefiningOp<DeinterleaveOp>();
8400 if (!lhsDefOp || !rhsDefOp || lhsDefOp != rhsDefOp)
8401 return failure();
8402 for (auto [idx, operand] : llvm::enumerate(interleaveOp.getOperands())) {
8403 if (cast<OpResult>(operand).getResultNumber() != idx)
8404 return failure();
8405 }
8406 rewriter.replaceOp(interleaveOp, lhsDefOp.getSource());
8407 return success();
8408 }
8409};
8410} // namespace
8411
8412void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
8413 MLIRContext *context) {
8414 results.add<InterleaveDeinterleaveFolder>(context);
8415}
8416
8417std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
8418 return llvm::to_vector<4>(getResultVectorType().getShape());
8419}
8420
8421//===----------------------------------------------------------------------===//
8422// DeinterleaveOp
8423//===----------------------------------------------------------------------===//
8424
8425std::optional<SmallVector<int64_t, 4>> DeinterleaveOp::getShapeForUnroll() {
8426 return llvm::to_vector<4>(getResultVectorType().getShape());
8427}
8428
8429//===----------------------------------------------------------------------===//
8430// TableGen'd op method definitions
8431//===----------------------------------------------------------------------===//
8432
8433#define GET_ATTRDEF_CLASSES
8434#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8435
8436#define GET_OP_CLASSES
8437#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition VectorOps.cpp:75
static OpFoldResult foldShuffleIdentityMask(ShuffleOp op)
Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static OpFoldResult foldShufflePoisonInputs(MLIRContext *context, Attribute v1Attr, Attribute v2Attr)
Fold shuffle poison, poison -> poison.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
Definition VectorOps.cpp:65
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr, Attribute v2Attr)
Fold a shuffle of constant 1-D inputs by evaluating the mask.
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op)
If a shuffle operand is poison, replace all mask indices that reference it with kPoisonIndex.
static LogicalResult foldSize1TransferPermutationMap(TransferOp op)
When the vector type is vector<1xT>, the permutation map is irrelevant: the single vector lane always...
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
#define mul(a, b)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
bool empty()
Definition Block.h:158
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
iterator begin()
Definition Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:286
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:275
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:323
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:567
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:859
void setOperand(unsigned idx, Value value)
Definition Operation.h:376
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:440
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isF32() const
Definition Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:114
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
bool hasNegativeStaticStride(MemRefType memRefTy)
Returns true if any stride of memRefTy is statically known to be negative.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition VectorOps.h:64
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requiring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition VectorOps.h:72
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Return a fused vector::ContractionOp which represents a pattern such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const