MLIR 23.0.0git
VectorOps.cpp
1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15
26#include "mlir/IR/AffineExpr.h"
27#include "mlir/IR/AffineMap.h"
28#include "mlir/IR/Builders.h"
32#include "mlir/IR/IRMapping.h"
36#include "mlir/IR/ValueRange.h"
39#include "mlir/Support/LLVM.h"
41#include "llvm/ADT/ArrayRef.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/StringSet.h"
45#include "llvm/ADT/TypeSwitch.h"
46#include "llvm/Support/Casting.h"
47
48#include <cassert>
49#include <cstdint>
50#include <numeric>
51
52#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
53// Pull in all enum type and utility function definitions.
54#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
55
56using namespace mlir;
57using namespace mlir::vector;
58
59/// Helper enum to classify a mask value.
60enum class MaskFormat {
61 AllTrue = 0,
62 AllFalse = 1,
63 Unknown = 2,
64};
65
66/// Helper method to classify a mask value. Currently, the method
67/// looks "under the hood" of a constant value with dense attributes
68/// and a constant mask operation (since the client may be called at
69/// various stages during progressive lowering).
70static MaskFormat getMaskFormat(Value mask) {
71 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
72 // Inspect constant dense values. We count up for bits that
73 // are set, count down for bits that are cleared, and bail
74 // when a mix is detected.
75 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
76 int64_t val = 0;
77 for (bool b : denseElts.getValues<bool>())
78 if (b && val >= 0)
79 val++;
80 else if (!b && val <= 0)
81 val--;
82 else
83 return MaskFormat::Unknown;
84 if (val > 0)
85 return MaskFormat::AllTrue;
86 if (val < 0)
87 return MaskFormat::AllFalse;
88 }
89 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
90 // Inspect constant mask index. If the index exceeds the
91 // dimension size, all bits are set. If the index is zero
92 // or less, no bits are set.
93 ArrayRef<int64_t> masks = m.getMaskDimSizes();
94 auto shape = m.getType().getShape();
95 bool allTrue = true;
96 bool allFalse = true;
97 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
98 if (maskIdx < dimSize)
99 allTrue = false;
100 if (maskIdx > 0)
101 allFalse = false;
102 }
103 if (allTrue)
104 return MaskFormat::AllTrue;
105 if (allFalse)
106 return MaskFormat::AllFalse;
107 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
108 // Finds all-false create_masks. An all-true create_mask requires all
109 // dims to be constants, so that'll be folded to a constant_mask, then
110 // detected in the constant_mask case.
111 auto maskOperands = m.getOperands();
112 for (Value operand : maskOperands) {
113 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
114 int64_t dimSize =
115 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
116 if (dimSize <= 0)
117 return MaskFormat::AllFalse;
118 }
119 }
120 return MaskFormat::Unknown;
121 }
122 return MaskFormat::Unknown;
123}
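// Illustrative behaviour (editorial example, not part of the upstream file):
// a constant mask such as vector.constant_mask [4] : vector<4xi1> classifies
// as AllTrue, [0] : vector<4xi1> as AllFalse, and [2] : vector<4xi1> as
// Unknown, mirroring the dimension-size comparisons above.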
124
125/// Default callback to build a region with a 'vector.yield' terminator with no
126/// arguments.
127static void buildTerminatedBody(OpBuilder &builder, Location loc) {
128 vector::YieldOp::create(builder, loc);
129}
130
131// Helper for verifying combining kinds in contractions and reductions.
132static bool isSupportedCombiningKind(CombiningKind combiningKind,
133 Type elementType) {
134 switch (combiningKind) {
135 case CombiningKind::ADD:
136 case CombiningKind::MUL:
137 return elementType.isIntOrIndexOrFloat();
138 case CombiningKind::MINUI:
139 case CombiningKind::MINSI:
140 case CombiningKind::MAXUI:
141 case CombiningKind::MAXSI:
142 case CombiningKind::AND:
143 case CombiningKind::OR:
144 case CombiningKind::XOR:
145 return elementType.isIntOrIndex();
146 case CombiningKind::MINNUMF:
147 case CombiningKind::MAXNUMF:
148 case CombiningKind::MINIMUMF:
149 case CombiningKind::MAXIMUMF:
150 return llvm::isa<FloatType>(elementType);
151 }
152 return false;
153}
154
155/// Returns the effective rank of the vector to read/write for Xfer Ops
156///
157/// When the element type of the shaped type is _a scalar_, this will simply
158/// return the rank of the vector (the result for xfer_read or the value to
159/// store for xfer_write).
160///
161/// When the element type of the base shaped type is _a vector_, returns the
162/// difference between the original vector type and the element type of the
163/// shaped type.
164///
165/// EXAMPLE 1 (element type is _a scalar_):
166/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167/// - shapedType.getElementType() = f32 (rank 0)
168/// - vectorType.getRank() = 2
169/// - Result = 2 - 0 = 2
170///
171/// EXAMPLE 2 (element type is _a vector_):
172/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
173/// - shapedType.getElementType() = vector<20xf32> (rank 1)
174/// - vectorType.getRank() = 1
175/// - Result = 1 - 1 = 0
176///
177/// This is used to determine the number of minor dimensions for identity maps
178/// in vector transfer Ops.
179static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
180 VectorType vectorType) {
181 unsigned elementVectorRank = 0;
182 VectorType elementVectorType =
183 llvm::dyn_cast<VectorType>(shapedType.getElementType());
184 if (elementVectorType)
185 elementVectorRank += elementVectorType.getRank();
186 return vectorType.getRank() - elementVectorRank;
187}
188
189AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
190 VectorType vectorType) {
191 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
192 // TODO: replace once we have 0-d vectors.
193 if (shapedType.getRank() == 0 &&
194 vectorType.getShape() == ArrayRef<int64_t>{1})
195 return AffineMap::get(
196 /*numDims=*/0, /*numSymbols=*/0,
197 getAffineConstantExpr(0, shapedType.getContext()));
198 return AffineMap::getMinorIdentityMap(
199 shapedType.getRank(),
200 getEffectiveVectorRankForXferOp(shapedType, vectorType),
201 shapedType.getContext());
202}
203
204/// Check if `write` is of a constant splat and the masked `read` is padded with
205/// the same splat value -- meaning it could be the same value as the initial
206/// constant splat.
207static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
208 vector::TransferReadOp read) {
209 auto readMask = read.getMask();
210 auto writeMask = write.getMask();
211 // Check if the masks are consistent. The splat value could be the same if the
212 // read is masked (and padded with the splat value), and the write is unmasked
213 // or has the same mask. Note this does not allow the case where the write is
214 // masked and the read is unmasked, as then the read could be of more elements
215 // than the write (which may not be the same value).
216 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
217 if (!couldBeSameSplat)
218 return false;
219 // Check for constant splat (as the source of the write).
220 DenseElementsAttr splatAttr;
221 if (!matchPattern(write.getVector(),
222 m_Constant<DenseElementsAttr>(&splatAttr)) ||
223 !splatAttr.isSplat()) {
224 return false;
225 }
226 // The padding of the read and the constant splat value must be the same.
227 Attribute padAttr;
228 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
229 return false;
230 return padAttr == splatAttr.getSplatValue<Attribute>();
231}
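// Illustrative IR shape this predicate is meant to match (editorial sketch,
// hypothetical SSA names, not from the upstream file):
//   %splat = arith.constant dense<0.0> : vector<4xf32>
//   vector.transfer_write %splat, %m[%i] : vector<4xf32>, memref<8xf32>
//   %r = vector.transfer_read %m[%i], %pad, %mask ...
// With %pad == 0.0 and the write unmasked (or masked identically), the
// masked-off lanes of the read are padded with the same splat value.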
232
233bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
234 vector::TransferReadOp read) {
235 return !defWrite.hasOutOfBoundsDim() &&
236 defWrite.getIndices() == read.getIndices() &&
237 defWrite.getVectorType() == read.getVectorType() &&
238 defWrite.getPermutationMap() == read.getPermutationMap() &&
239 ((!defWrite.getMask() && !read.getMask()) ||
240 defWrite.getMask() == read.getMask());
241}
242
243bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
244 vector::TransferWriteOp priorWrite) {
245 return priorWrite.getIndices() == write.getIndices() &&
246 priorWrite.getMask() == write.getMask() &&
247 priorWrite.getVectorType() == write.getVectorType() &&
248 priorWrite.getPermutationMap() == write.getPermutationMap();
249}
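// Editorial sketch (hypothetical SSA names, not from the upstream file) of
// the RAW case these predicates enable:
//   vector.transfer_write %v, %buf[%i] {in_bounds = [true]} : vector<4xf32>, memref<8xf32>
//   %r = vector.transfer_read %buf[%i], %pad : memref<8xf32>, vector<4xf32>
// With matching indices, vector types, permutation maps and masks, %r can be
// replaced by %v (store-to-load forwarding).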
250
251bool mlir::vector::isDisjointTransferIndices(
252 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
253 bool testDynamicValueUsingBounds) {
254 // For simplicity only look at transfer of same type.
255 if (transferA.getVectorType() != transferB.getVectorType())
256 return false;
257 unsigned rankOffset = transferA.getLeadingShapedRank();
258 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
259 Value indexA = transferA.getIndices()[i];
260 Value indexB = transferB.getIndices()[i];
261 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
262 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
263
264 if (i < rankOffset) {
265 // For leading dimensions, if we can prove that the indices are different,
266 // we know we are accessing disjoint slices.
267 if (cstIndexA.has_value() && cstIndexB.has_value()) {
268 if (*cstIndexA != *cstIndexB)
269 return true;
270 continue;
271 }
272 if (testDynamicValueUsingBounds) {
273 // First try to see if we can fully compose and simplify the affine
274 // expression as a fast track.
275 FailureOr<uint64_t> delta =
276 affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
277 if (succeeded(delta) && *delta != 0)
278 return true;
279
280 FailureOr<bool> testEqual =
281 ValueBoundsConstraintSet::areEqual(indexA, indexB);
282 if (succeeded(testEqual) && !testEqual.value())
283 return true;
284 }
285 } else {
286 // For this dimension, we slice a part of the memref, so we need to make
287 // sure the intervals accessed don't overlap.
288 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
289 if (cstIndexA.has_value() && cstIndexB.has_value()) {
290 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
291 if (distance >= vectorDim)
292 return true;
293 continue;
294 }
295 if (testDynamicValueUsingBounds) {
296 // First try to see if we can fully compose and simplify the affine
297 // expression as a fast track.
298 FailureOr<int64_t> delta =
299 affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
300 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
301 return true;
302
303 FailureOr<int64_t> computeDelta =
304 ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
305 if (succeeded(computeDelta)) {
306 if (std::abs(computeDelta.value()) >= vectorDim)
307 return true;
308 }
309 }
310 }
311 }
312 return false;
313}
314
315bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
316 VectorTransferOpInterface transferB,
317 bool testDynamicValueUsingBounds) {
318 if (transferA.getBase() != transferB.getBase())
319 return false;
320 return isDisjointTransferIndices(transferA, transferB,
321 testDynamicValueUsingBounds);
322}
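// Worked example (editorial, not from the upstream file): for two
// vector<4xf32> transfers on the same base, constant indices 0 and 8 are
// disjoint because |0 - 8| >= 4, while indices 0 and 2 may overlap because
// |0 - 2| < 4; leading (non-transferred) dimensions only need to differ.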
323
324// Helper to iterate over n-D vector slice elements. Calculate the next
325// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
326// Modifies the `position` in place. Returns a failure when `position` becomes
327// the end position.
328static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
329 ArrayRef<int64_t> shape,
330 ArrayRef<int64_t> offsets) {
331 for (auto [posInDim, dimSize, offsetInDim] :
332 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
333 ++posInDim;
334 if (posInDim < dimSize + offsetInDim)
335 return success();
336
337 // Carry the overflow to the next loop iteration.
338 posInDim = offsetInDim;
339 }
340
341 return failure();
342}
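// Worked example (editorial, not from the upstream file): with
// shape = [2, 3] and offsets = [0, 0], repeated calls advance `position`
// like an odometer: [0,0] -> [0,1] -> [0,2] -> [1,0] -> [1,1] -> [1,2],
// and the next call returns failure() once every dimension has wrapped.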
343
344/// Returns the integer numbers in `values`. `values` are expected to be
345/// constant operations.
346SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
347 SmallVector<int64_t> ints;
348 llvm::transform(values, std::back_inserter(ints), [](Value value) {
349 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
350 assert(constOp && "Unexpected non-constant index");
351 return constOp.value();
352 });
353 return ints;
354}
355
356/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
357/// be constant operations.
358SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
359 SmallVector<int64_t> ints;
360 llvm::transform(
361 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
362 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
363 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
364 });
365 return ints;
366}
367
368/// Convert `foldResults` into Values. Integer attributes are converted to
369/// constant op.
370SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
371 ArrayRef<OpFoldResult> foldResults) {
372 SmallVector<Value> values;
373 llvm::transform(foldResults, std::back_inserter(values),
374 [&](OpFoldResult foldResult) {
375 if (auto attr = dyn_cast<Attribute>(foldResult))
376 return arith::ConstantIndexOp::create(
377 builder, loc, cast<IntegerAttr>(attr).getInt())
378 .getResult();
379
380 return cast<Value>(foldResult);
381 });
382 return values;
383}
384
385std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
386 if (value.getDefiningOp<vector::VectorScaleOp>())
387 return 1;
388 auto mul = value.getDefiningOp<arith::MulIOp>();
389 if (!mul)
390 return {};
391 auto lhs = mul.getLhs();
392 auto rhs = mul.getRhs();
393 if (lhs.getDefiningOp<vector::VectorScaleOp>())
394 return getConstantIntValue(rhs);
395 if (rhs.getDefiningOp<vector::VectorScaleOp>())
396 return getConstantIntValue(lhs);
397 return {};
398}
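// Illustrative example (hypothetical SSA names, not from the upstream file):
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index
// getConstantVscaleMultiplier(%n) returns 4; for %vs alone it returns 1, and
// for anything else it returns std::nullopt.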
399
400/// Converts numeric attributes to the expected type. Supports
401/// integer-to-integer and float-to-integer conversions. Returns the original
402/// attribute if no conversion is needed or supported.
403static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
404 // Integer-to-integer conversion
405 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
406 if (auto intType = dyn_cast<IntegerType>(expectedType)) {
407 if (intAttr.getType() != expectedType)
408 return IntegerAttr::get(expectedType, intAttr.getInt());
409 }
410 return attr;
411 }
412
413 // Float-to-integer bitcast (preserves bit representation)
414 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
415 auto intType = dyn_cast<IntegerType>(expectedType);
416 if (!intType)
417 return attr;
418
419 APFloat floatVal = floatAttr.getValue();
420 APInt intVal = floatVal.bitcastToAPInt();
421 return IntegerAttr::get(expectedType, intVal);
422 }
423
424 return attr;
425}
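// Illustrative behaviour (editorial, not from the upstream file): an f32
// FloatAttr holding 1.0 converted to i32 becomes the IntegerAttr 0x3F800000
// (the raw IEEE-754 bits), while an IntegerAttr of a different width is
// rebuilt with the same value in the expected integer type.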
426
427//===----------------------------------------------------------------------===//
428// CombiningKindAttr
429//===----------------------------------------------------------------------===//
430
431namespace mlir {
432namespace vector {
433namespace detail {
434struct BitmaskEnumStorage : public AttributeStorage {
435 using KeyTy = uint64_t;
436
437 BitmaskEnumStorage(KeyTy val) : value(val) {}
438
439 bool operator==(const KeyTy &key) const { return value == key; }
440
441 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
442 const KeyTy &key) {
443 return new (allocator.allocate<BitmaskEnumStorage>())
444 BitmaskEnumStorage(key);
445 }
446
447 KeyTy value = 0;
448};
449} // namespace detail
450} // namespace vector
451} // namespace mlir
452
453//===----------------------------------------------------------------------===//
454// VectorDialect
455//===----------------------------------------------------------------------===//
456
457namespace {
458/// This class defines the interface for handling inlining with vector dialect
459/// operations.
460struct VectorInlinerInterface : public DialectInlinerInterface {
461 using DialectInlinerInterface::DialectInlinerInterface;
462
463 /// All vector dialect ops can be inlined.
464 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
465 return true;
466 }
467};
468} // namespace
469
470void VectorDialect::initialize() {
471 addAttributes<
472#define GET_ATTRDEF_LIST
473#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
474 >();
475
476 addOperations<
477#define GET_OP_LIST
478#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
479 >();
480
481 addInterfaces<VectorInlinerInterface>();
482
483 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
484 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
485 YieldOp>();
486 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
487 TransferWriteOp>();
488 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
489 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
490 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
491}
492
493/// Materialize a single constant operation from a given attribute value with
494/// the desired resultant type.
495Operation *VectorDialect::materializeConstant(OpBuilder &builder,
496 Attribute value, Type type,
497 Location loc) {
498 if (isa<ub::PoisonAttrInterface>(value))
499 return value.getDialect().materializeConstant(builder, value, type, loc);
500
501 return arith::ConstantOp::materialize(builder, value, type, loc);
502}
503
504IntegerType vector::getVectorSubscriptType(Builder &builder) {
505 return builder.getIntegerType(64);
506}
507
508ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
509 ArrayRef<int64_t> values) {
510 return builder.getI64ArrayAttr(values);
511}
512
513//===----------------------------------------------------------------------===//
514// MultiDimReductionOp
515//===----------------------------------------------------------------------===//
516
517void vector::MultiDimReductionOp::build(OpBuilder &builder,
518 OperationState &result, Value source,
519 Value acc, ArrayRef<bool> reductionMask,
520 CombiningKind kind) {
521 SmallVector<int64_t> reductionDims;
522 for (const auto &en : llvm::enumerate(reductionMask))
523 if (en.value())
524 reductionDims.push_back(en.index());
525 build(builder, result, kind, source, acc, reductionDims);
526}
527
528OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
529 // Single parallel dim, this is a noop.
530 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
531 return getSource();
532 return {};
533}
534
535std::optional<SmallVector<int64_t, 4>>
536MultiDimReductionOp::getShapeForUnroll() {
537 return llvm::to_vector<4>(getSourceVectorType().getShape());
538}
539
540LogicalResult MultiDimReductionOp::verify() {
541 SmallVector<int64_t> targetShape;
542 SmallVector<bool> scalableDims;
543 Type inferredReturnType;
544 auto sourceScalableDims = getSourceVectorType().getScalableDims();
545 for (auto [dimIdx, dimSize] :
546 llvm::enumerate(getSourceVectorType().getShape()))
547 if (!llvm::any_of(getReductionDims(),
548 [dimIdx = dimIdx](int64_t reductionDimIdx) {
549 return reductionDimIdx == static_cast<int64_t>(dimIdx);
550 })) {
551 targetShape.push_back(dimSize);
552 scalableDims.push_back(sourceScalableDims[dimIdx]);
553 }
554 // TODO: update to also allow 0-d vectors when available.
555 if (targetShape.empty())
556 inferredReturnType = getSourceVectorType().getElementType();
557 else
558 inferredReturnType = VectorType::get(
559 targetShape, getSourceVectorType().getElementType(), scalableDims);
560 if (getType() != inferredReturnType)
561 return emitOpError() << "destination type " << getType()
562 << " is incompatible with source type "
563 << getSourceVectorType();
564
565 return success();
566}
567
568/// Returns the mask type expected by this operation.
569Type MultiDimReductionOp::getExpectedMaskType() {
570 auto vecType = getSourceVectorType();
571 return VectorType::get(vecType.getShape(),
572 IntegerType::get(vecType.getContext(), /*width=*/1),
573 vecType.getScalableDims());
574}
575
576namespace {
577// Only unit dimensions that are being reduced are folded. If the dimension is
578// unit, but not reduced, it is not folded, thereby keeping the output type the
579// same. If not all dimensions which are reduced are of unit dimension, this
580// transformation does nothing. This is just a generalization of
581// ElideSingleElementReduction for ReduceOp.
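// Illustrative rewrite (editorial example, not from the upstream file):
//   %r = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
// reduces only a unit dimension, so it becomes a vector.shape_cast of %src to
// vector<4xf32> combined with %acc via the ADD arith reduction; if every
// dimension is reduced and all are unit, a vector.extract is used instead.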
582struct ElideUnitDimsInMultiDimReduction
583 : public OpRewritePattern<MultiDimReductionOp> {
584 using Base::Base;
585
586 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
587 PatternRewriter &rewriter) const override {
588 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
589 for (const auto &dim : enumerate(shape)) {
590 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
591 return failure();
592 }
593
594 // Vector mask setup.
595 OpBuilder::InsertionGuard guard(rewriter);
596 Operation *rootOp;
597 Value mask;
598 if (reductionOp.isMasked()) {
599 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
600 rootOp = reductionOp.getMaskingOp();
601 mask = reductionOp.getMaskingOp().getMask();
602 } else {
603 rootOp = reductionOp;
604 }
605
606 Location loc = reductionOp.getLoc();
607 Value acc = reductionOp.getAcc();
608 Value cast;
609 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
610 if (mask) {
611 VectorType newMaskType =
612 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
613 dstVecType.getScalableDims());
614 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
615 }
616 cast = vector::ShapeCastOp::create(
617 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
618 } else {
619 // This means we are reducing all the dimensions, and all reduction
620 // dimensions are of size 1. So a simple extraction would do.
621 if (mask)
622 mask = vector::ExtractOp::create(rewriter, loc, mask);
623 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
624 }
625
626 Value result =
627 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
628 cast, /*fastmath=*/nullptr, mask);
629 rewriter.replaceOp(rootOp, result);
630 return success();
631 }
632};
633} // namespace
634
635void MultiDimReductionOp::getCanonicalizationPatterns(
636 RewritePatternSet &results, MLIRContext *context) {
637 results.add<ElideUnitDimsInMultiDimReduction>(context);
638}
639
640//===----------------------------------------------------------------------===//
641// ReductionOp
642//===----------------------------------------------------------------------===//
643
644void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
645 CombiningKind kind, Value vector,
646 arith::FastMathFlags fastMathFlags) {
647 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
648}
649
650void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
651 CombiningKind kind, Value vector, Value acc,
652 arith::FastMathFlags fastMathFlags) {
653 build(builder, result,
654 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
655 acc, fastMathFlags);
656}
657
658LogicalResult ReductionOp::verify() {
659 // Verify for 0-D and 1-D vector.
660 int64_t rank = getSourceVectorType().getRank();
661 if (rank > 1)
662 return emitOpError("unsupported reduction rank: ") << rank;
663
664 // Verify supported reduction kind.
665 Type eltType = getDest().getType();
666 if (!isSupportedCombiningKind(getKind(), eltType))
667 return emitOpError("unsupported reduction type '")
668 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
669 << "'";
670
671 return success();
672}
673
674// MaskableOpInterface methods.
675
676/// Returns the mask type expected by this operation.
677Type ReductionOp::getExpectedMaskType() {
678 auto vecType = getSourceVectorType();
679 return VectorType::get(vecType.getShape(),
680 IntegerType::get(vecType.getContext(), /*width=*/1),
681 vecType.getScalableDims());
682}
683
684Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
685 OpBuilder &builder, Location loc,
686 Value vector) {
687 switch (op) {
688 case arith::AtomicRMWKind::addf:
689 case arith::AtomicRMWKind::addi:
690 return vector::ReductionOp::create(builder, vector.getLoc(),
691 CombiningKind::ADD, vector);
692 case arith::AtomicRMWKind::mulf:
693 case arith::AtomicRMWKind::muli:
694 return vector::ReductionOp::create(builder, vector.getLoc(),
695 CombiningKind::MUL, vector);
696 case arith::AtomicRMWKind::minimumf:
697 return vector::ReductionOp::create(builder, vector.getLoc(),
698 CombiningKind::MINIMUMF, vector);
699 case arith::AtomicRMWKind::mins:
700 return vector::ReductionOp::create(builder, vector.getLoc(),
701 CombiningKind::MINSI, vector);
702 case arith::AtomicRMWKind::minu:
703 return vector::ReductionOp::create(builder, vector.getLoc(),
704 CombiningKind::MINUI, vector);
705 case arith::AtomicRMWKind::maximumf:
706 return vector::ReductionOp::create(builder, vector.getLoc(),
707 CombiningKind::MAXIMUMF, vector);
708 case arith::AtomicRMWKind::maxs:
709 return vector::ReductionOp::create(builder, vector.getLoc(),
710 CombiningKind::MAXSI, vector);
711 case arith::AtomicRMWKind::maxu:
712 return vector::ReductionOp::create(builder, vector.getLoc(),
713 CombiningKind::MAXUI, vector);
714 case arith::AtomicRMWKind::andi:
715 return vector::ReductionOp::create(builder, vector.getLoc(),
716 CombiningKind::AND, vector);
717 case arith::AtomicRMWKind::ori:
718 return vector::ReductionOp::create(builder, vector.getLoc(),
719 CombiningKind::OR, vector);
720 case arith::AtomicRMWKind::minnumf:
721 return vector::ReductionOp::create(builder, vector.getLoc(),
722 CombiningKind::MINNUMF, vector);
723 case arith::AtomicRMWKind::maxnumf:
724 return vector::ReductionOp::create(builder, vector.getLoc(),
725 CombiningKind::MAXNUMF, vector);
726 case arith::AtomicRMWKind::xori:
727 return vector::ReductionOp::create(builder, vector.getLoc(),
728 CombiningKind::XOR, vector);
729 default:
730 (void)emitOptionalError(loc, "Reduction operation type not supported");
731 break;
732 }
733 return nullptr;
734}
735
736std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
737 return llvm::to_vector<4>(getSourceVectorType().getShape());
738}
739
740namespace {
741struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
742 using Base::Base;
743
744 LogicalResult matchAndRewrite(ReductionOp reductionOp,
745 PatternRewriter &rewriter) const override {
746 // Vector mask setup.
747 OpBuilder::InsertionGuard guard(rewriter);
748 auto maskableOp =
749 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
750 Operation *rootOp;
751 Value mask;
752 if (maskableOp.isMasked()) {
753 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
754 rootOp = maskableOp.getMaskingOp();
755 mask = maskableOp.getMaskingOp().getMask();
756 } else {
757 rootOp = reductionOp;
758 }
759
760 auto vectorType = reductionOp.getSourceVectorType();
761 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
762 return failure();
763
764 Location loc = reductionOp.getLoc();
765 if (mask)
766 mask = ExtractOp::create(rewriter, loc, mask);
767 Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
768
769 if (Value acc = reductionOp.getAcc())
770 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
771 result, acc,
772 reductionOp.getFastmathAttr(), mask);
773
774 rewriter.replaceOp(rootOp, result);
775 return success();
776 }
777};
778} // namespace
779
780void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
781 MLIRContext *context) {
782 results.add<ElideSingleElementReduction>(context);
783}
784
785//===----------------------------------------------------------------------===//
786// ContractionOp
787//===----------------------------------------------------------------------===//
788
789void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
790 Value lhs, Value rhs, Value acc,
791 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
792 ArrayRef<IteratorType> iteratorTypes) {
793 result.addOperands({lhs, rhs, acc});
794 result.addTypes(acc.getType());
795 result.addAttribute(
796 getIndexingMapsAttrName(result.name),
797 builder.getAffineMapArrayAttr(
798 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
799 result.addAttribute(
800 getIteratorTypesAttrName(result.name),
801 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
802 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
803 return IteratorTypeAttr::get(builder.getContext(), t);
804 }))));
805}
806
807void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
808 Value lhs, Value rhs, Value acc,
809 ArrayAttr indexingMaps,
810 ArrayAttr iteratorTypes) {
811 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
812 ContractionOp::getDefaultKind());
813}
814
815void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
816 Value lhs, Value rhs, Value acc,
817 ArrayAttr indexingMaps,
818 ArrayAttr iteratorTypes, CombiningKind kind) {
819 result.addOperands({lhs, rhs, acc});
820 result.addTypes(acc.getType());
821 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
822 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
823 result.addAttribute(getKindAttrName(result.name),
824 CombiningKindAttr::get(builder.getContext(), kind));
825}
826
827ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
828 OpAsmParser::UnresolvedOperand lhsInfo;
829 OpAsmParser::UnresolvedOperand rhsInfo;
830 OpAsmParser::UnresolvedOperand accInfo;
831 SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
832 SmallVector<Type, 2> types;
833 Type resultType;
834 auto loc = parser.getCurrentLocation();
835 DictionaryAttr dictAttr;
836 // TODO: Unify linalg op attribute parsing.
837 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
838 parser.parseComma() || parser.parseOperand(rhsInfo) ||
839 parser.parseComma() || parser.parseOperand(accInfo) ||
840 parser.parseTrailingOperandList(masksInfo) ||
841 parser.parseOptionalAttrDict(result.attributes) ||
842 parser.parseColonTypeList(types) ||
843 parser.parseKeywordType("into", resultType) ||
844 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
845 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
846 parser.resolveOperand(accInfo, resultType, result.operands) ||
847 parser.addTypeToList(resultType, result.types))
848 return failure();
849 result.attributes.append(dictAttr.getValue().begin(),
850 dictAttr.getValue().end());
851
852 // Convert the array of strings into an array of IteratorType enums. This is
853 // needed because tests still use the old format where the 'iterator_types'
854 // attribute is represented as an array of strings.
855 // TODO: Remove this conversion once tests are fixed.
856 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
857 result.attributes.get(getIteratorTypesAttrName(result.name)));
858 if (!iteratorTypes) {
859 return parser.emitError(loc)
860 << "expected " << getIteratorTypesAttrName(result.name)
861 << " array attribute";
862 }
863
864 SmallVector<Attribute> iteratorTypeAttrs;
865
866 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
867 auto maybeIteratorType = symbolizeIteratorType(s);
868 if (!maybeIteratorType.has_value())
869 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
870
871 iteratorTypeAttrs.push_back(
872 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
873 }
874 result.attributes.set(getIteratorTypesAttrName(result.name),
875 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
876
877 if (!result.attributes.get(getKindAttrName(result.name))) {
878 result.addAttribute(
879 getKindAttrName(result.name),
880 CombiningKindAttr::get(result.getContext(),
881 ContractionOp::getDefaultKind()));
882 }
883 if (masksInfo.empty())
884 return success();
885 if (masksInfo.size() != 2)
886 return parser.emitError(parser.getNameLoc(),
887 "expected zero or exactly 2 vector mask operands");
888 auto lhsType = llvm::cast<VectorType>(types[0]);
889 auto rhsType = llvm::cast<VectorType>(types[1]);
890 auto maskElementType = parser.getBuilder().getI1Type();
891 std::array<VectorType, 2> maskTypes = {
892 VectorType::Builder(lhsType).setElementType(maskElementType),
893 VectorType::Builder(rhsType).setElementType(maskElementType)};
894 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
895 return failure();
896 return success();
897}
898
899void ContractionOp::print(OpAsmPrinter &p) {
900 // TODO: Unify printing code with linalg ops.
901 auto attrNames = getTraitAttrNames();
902 llvm::StringSet<> traitAttrsSet;
903 traitAttrsSet.insert_range(attrNames);
904 SmallVector<NamedAttribute> attrs;
905 for (auto attr : (*this)->getAttrs()) {
906 if (attr.getName() == getIteratorTypesAttrName()) {
907 auto iteratorTypes =
908 llvm::cast<ArrayAttr>(attr.getValue())
909 .getAsValueRange<IteratorTypeAttr, IteratorType>();
910 // Convert IteratorType enums into their string representation. This is
911 // needed because tests still use the old format where the 'iterator_types'
912 // attribute is represented as an array of strings.
913 // TODO: Remove this conversion once tests are fixed.
914 SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
915 llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
916 return StringAttr::get(getContext(), stringifyIteratorType(t));
917 }));
918
919 attrs.emplace_back(getIteratorTypesAttrName(),
920 ArrayAttr::get(getContext(), iteratorTypeNames));
921 } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
922 attrs.push_back(attr);
923 }
924
925 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
926 p << " " << dictAttr << " " << getLhs() << ", ";
927 p << getRhs() << ", " << getAcc();
928
929 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
930 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
931 << getResultType();
932}
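// For reference, an illustrative textual form produced by this printer
// (editorial example, not from the upstream file):
//   %r = vector.contract {indexing_maps = [...], iterator_types = [...],
//                         kind = #vector.kind<add>} %lhs, %rhs, %acc
//        : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>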
933
934static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
935 const std::vector<std::pair<int64_t, int64_t>> &map) {
936 for (auto &dimPair : map) {
937 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
938 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
939 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
940 return false;
941 }
942 return true;
943}
944
945static LogicalResult verifyOutputShape(
946 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
947 Type resType,
948 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
949 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
950 DenseSet<int64_t> lhsContractingDimSet;
951 DenseSet<int64_t> rhsContractingDimSet;
952 for (auto &dimPair : contractingDimMap) {
953 lhsContractingDimSet.insert(dimPair.first);
954 rhsContractingDimSet.insert(dimPair.second);
955 }
956 DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
957 llvm::make_second_range(batchDimMap));
958
959 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
960 SmallVector<int64_t, 4> expectedResultDims;
961 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
962 if (lhsContractingDimSet.count(i) > 0)
963 continue;
964 expectedResultDims.push_back(lhsType.getDimSize(i));
965 }
966
967 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
968 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
969 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
970 continue;
971 expectedResultDims.push_back(rhsType.getDimSize(i));
972 }
973
974 // Verify 'expectedResultDims'.
975 if (expectedResultDims.empty()) {
976 // No batch or free dimension implies a scalar result.
977 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
978 return op.emitOpError("invalid accumulator/result vector shape");
979 } else {
980 // At least one batch or free dimension implies a vector result.
981 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
982 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
983 if (!resVectorType || !accVectorType)
984 return op.emitOpError("invalid accumulator/result vector shape");
985
986 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
987 // types fully define the result vector type. This assumes the affine maps
988 // are well-formed, which must have been verified already.
989 MLIRContext *ctx = op.getContext();
990 AffineMap lhsMap = op.getIndexingMapsArray()[0];
991 AffineMap rhsMap = op.getIndexingMapsArray()[1];
992 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
993 return op.emitOpError(
994 "expected all dimensions to be either a LHS or a RHS dimension");
995 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs(), nullptr);
996 for (auto pair :
997 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
998 VectorType v = pair.first;
999 auto map = pair.second;
1000 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1001 unsigned pos = map.getDimPosition(idx);
1002 if (!extents[pos])
1003 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
1004 }
1005 }
1006 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
1007 return op.emitOpError("expected all dimensions to get an extent as "
1008 "either a LHS or a RHS dimension");
1009
1010 AffineMap resMap = op.getIndexingMapsArray()[2];
1011 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
1012 /*symbolCount=*/0, extents, ctx);
1013 // Compose the resMap with the extentsMap, which is a constant map.
1014 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
1015 assert(llvm::all_of(expectedMap.getResults(),
1016 llvm::IsaPred<AffineConstantExpr>) &&
1017 "expected constant extent along all dimensions.");
1018 // Extract the expected shape and build the type.
1019 auto expectedShape = llvm::to_vector<4>(
1020 llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
1021 return cast<AffineConstantExpr>(e).getValue();
1022 }));
1023 auto expected =
1024 VectorType::get(expectedShape, resVectorType.getElementType(),
1025 resVectorType.getScalableDims());
1026 if (resVectorType != expected || accVectorType != expected)
1027 return op.emitOpError(
1028 "invalid accumulator/result vector shape, expected: ")
1029 << expected;
1030 }
1031 return success();
1032}
1033
1034LogicalResult ContractionOp::verify() {
1035 VectorType lhsType = getLhsType();
1036 VectorType rhsType = getRhsType();
1037 Type accType = getAccType();
1038 Type resType = getResultType();
1039
1040 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1041 if (!lhsType.getElementType().isSignlessInteger())
1042 return emitOpError("only supports signless integer types");
1043 }
1044
1045 // Verify that an indexing map was specified for each vector operand.
1046 if (getIndexingMapsArray().size() != 3)
1047 return emitOpError("expected an indexing map for each vector operand");
1048
1049 // Verify that each index map has 'numIterators' inputs, no symbols, and
1050 // that the number of map outputs equals the rank of its associated
1051 // vector operand.
1052 unsigned numIterators = getIteratorTypes().getValue().size();
1053 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1054 auto index = it.index();
1055 auto map = it.value();
1056 if (map.getNumSymbols() != 0)
1057 return emitOpError("expected indexing map ")
1058 << index << " to have no symbols";
1059 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1060 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1061 // Verify that the map has the right number of inputs, outputs, and indices.
1062 // This also correctly accounts for (..) -> () for rank-0 results.
1063 if (map.getNumDims() != numIterators)
1064 return emitOpError("expected indexing map ")
1065 << index << " to have " << numIterators << " number of inputs";
1066 if (map.getNumResults() != rank)
1067 return emitOpError("expected indexing map ")
1068 << index << " to have " << rank << " number of outputs";
1069 if (!map.isProjectedPermutation())
1070 return emitOpError("expected indexing map ")
1071 << index << " to be a projected permutation of its inputs";
1072 }
1073
1074 auto contractingDimMap = getContractingDimMap();
1075 auto batchDimMap = getBatchDimMap();
1076
1077 // Verify at least one contracting dimension pair was specified.
1078 if (contractingDimMap.empty())
1079 return emitOpError("expected at least one contracting dimension pair");
1080
1081 // Verify contracting dimension map was properly constructed.
1082 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1083 return emitOpError("invalid contracting dimension map");
1084
1085 // Verify batch dimension map was properly constructed.
1086 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1087 return emitOpError("invalid batch dimension map");
1088
1089 // Verify 'accType' and 'resType' shape.
1090 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1091 contractingDimMap, batchDimMap)))
1092 return failure();
1093
1094 if (!getKindAttr()) {
1095 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
1096 "'vector.kind<add>')");
1097 }
1098
1099 // Verify supported combining kind.
1100 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1101 auto elementType = vectorType ? vectorType.getElementType() : resType;
1102 if (!isSupportedCombiningKind(getKind(), elementType))
1103 return emitOpError("unsupported contraction type");
1104
1105 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1106 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1107}
1108
1109// MaskableOpInterface methods.
1110
1111/// Returns the mask type expected by this operation. Mostly used for
1112/// verification purposes. It requires the operation to be vectorized.
1113Type ContractionOp::getExpectedMaskType() {
1114 auto indexingMaps = this->getIndexingMapsArray();
1115 AffineMap lhsIdxMap = indexingMaps[0];
1116 AffineMap rhsIdxMap = indexingMaps[1];
1117 VectorType lhsType = this->getLhsType();
1118 VectorType rhsType = this->getRhsType();
1119
1120 unsigned numVecDims = lhsIdxMap.getNumDims();
1121 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1122 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1123
1124 // Using the information in the indexing maps, extract the size of each
1125 // dimension in the vector.contract operation from the two input operands.
1126 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1127 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1128 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1129 lhsType.getScalableDims()[dimIdx];
1130 }
1131 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1132 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1133 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1134 rhsType.getScalableDims()[dimIdx];
1135 }
1136
1137 assert(ShapedType::isStaticShape(maskShape) &&
1138 "Mask shape couldn't be computed");
1139
1140 return VectorType::get(maskShape,
1141 IntegerType::get(lhsType.getContext(), /*width=*/1),
1142 maskShapeScalableDims);
1143}
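// Illustrative example (editorial, not from the upstream file): for a
// matmul-style contraction with iterators (m, n, k), lhs of type
// vector<4x8xf32> mapped to (m, k) and rhs of type vector<8x16xf32> mapped to
// (k, n), the expected mask type is vector<4x16x8xi1>, one i1 bit per
// (m, n, k) iteration point.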
1144
1145SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1146 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1147 getIteratorTypesAttrName(), getKindAttrName()};
1148}
1149
1150static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1151 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1152 if (targetExpr == map.getResult(i))
1153 return i;
1154 return -1;
1155}
1156
1157static std::vector<std::pair<int64_t, int64_t>>
1158getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1159 IteratorType targetIteratorType, MLIRContext *context) {
1160 std::vector<std::pair<int64_t, int64_t>> dimMap;
1161 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1162 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1163 if (iteratorType != targetIteratorType)
1164 continue;
1165 // Search lhs/rhs map results for 'targetExpr'.
1166 auto targetExpr = getAffineDimExpr(it.index(), context);
1167 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1168 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1169 if (lhsDim >= 0 && rhsDim >= 0)
1170 dimMap.emplace_back(lhsDim, rhsDim);
1171 }
1172 return dimMap;
1173}
1174
1175void ContractionOp::getIterationBounds(
1176 SmallVectorImpl<int64_t> &iterationBounds) {
1177 auto lhsShape = getLhsType().getShape();
1178 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1179 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1180 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1181 // Search lhs/rhs map results for 'targetExpr'.
1182 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1183 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1184 if (iteratorType == IteratorType::reduction) {
1185 // Get reduction dim size from lhs shape (same size in rhsShape).
1186 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1187 assert(lhsDimIndex >= 0);
1188 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1189 continue;
1190 }
1191 // Get parallel dimension size from result shape.
1192 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1193 assert(resDimIndex >= 0);
1194 assert(resVectorType != nullptr);
1195 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1196 }
1197}
1198
1199void ContractionOp::getIterationIndexMap(
1200 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1201 unsigned numMaps = getIndexingMapsArray().size();
1202 iterationIndexMap.resize(numMaps);
1203 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1204 auto index = it.index();
1205 auto map = it.value();
1206 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1207 auto dim = cast<AffineDimExpr>(map.getResult(i));
1208 iterationIndexMap[index][dim.getPosition()] = i;
1209 }
1210 }
1211}
1212
1213std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1214 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1215 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1216 getContext());
1217}
1218
1219std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1220 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1221 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1222 getContext());
1223}
1224
1225std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1226 SmallVector<int64_t, 4> shape;
1227 getIterationBounds(shape);
1228 return shape;
1229}
1230
1231/// Return a fused vector::ContractionOp which represents a pattern such as:
1232///
1233/// ```mlir
1234/// %c0 = vector.constant 0: ...
1235/// %c = vector.contract %a, %b, %c0: ...
1236/// %e = add %c, %d: ...
1237/// ```
1238///
1239/// by:
1240///
1241/// ```mlir
1242/// %e = vector.contract %a, %b, %d: ...
1243/// ```
1244///
1245/// Return null if the canonicalization does not apply.
1246// TODO: This should be a folding of Add into Contract in core but while they
1247// live in different dialects, it is not possible without unnatural
1248// dependencies.
1249template <typename AddOpType>
1250struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1251 using OpRewritePattern<AddOpType>::OpRewritePattern;
1252
1253 LogicalResult matchAndRewrite(AddOpType addOp,
1254 PatternRewriter &rewriter) const override {
1255 auto canonicalize = [&](Value maybeContraction,
1256 Value otherOperand) -> vector::ContractionOp {
1257 vector::ContractionOp contractionOp =
1258 dyn_cast_or_null<vector::ContractionOp>(
1259 maybeContraction.getDefiningOp());
1260 if (!contractionOp)
1261 return vector::ContractionOp();
1262 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1263 contractionOp.getAcc().getDefiningOp())) {
1264 if (maybeZero.getValue() ==
1265 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1266 IRMapping bvm;
1267 bvm.map(contractionOp.getAcc(), otherOperand);
1268 auto newContraction =
1269 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1270 rewriter.replaceOp(addOp, newContraction.getResult());
1271 return newContraction;
1272 }
1273 }
1274 return vector::ContractionOp();
1275 };
1276
1277 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1278 vector::ContractionOp contract = canonicalize(a, b);
1279 contract = contract ? contract : canonicalize(b, a);
1280 return contract ? success() : failure();
1281 }
1282};
1283
1284void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1285 MLIRContext *context) {
1286 results.add<CanonicalizeContractAdd<arith::AddIOp>,
1287 CanonicalizeContractAdd<arith::AddFOp>>(context);
1288}
1289
1290// Returns `true` if `index` is either within [0, maxIndex) or equal to
1291// `poisonValue`.
1292static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1293 int64_t maxIndex) {
1294 return index == poisonValue || (index >= 0 && index < maxIndex);
1295}
1296
1297//===----------------------------------------------------------------------===//
1298// ExtractOp
1299//===----------------------------------------------------------------------===//
1300
1301void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1302 SetIntRangeFn setResultRanges) {
1303 setResultRanges(getResult(), argRanges.front());
1304}
1305
1306void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1307 Value source) {
1308 auto vectorTy = cast<VectorType>(source.getType());
1309 build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1310}
1311
1312void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1313 Value source, int64_t position) {
1314 build(builder, result, source, ArrayRef<int64_t>{position});
1315}
1316
1317void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1318 Value source, OpFoldResult position) {
1319 build(builder, result, source, ArrayRef<OpFoldResult>{position});
1320}
1321
1322void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1323 Value source, ArrayRef<int64_t> position) {
1324 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1325 builder.getDenseI64ArrayAttr(position));
1326}
1327
1328void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1329 Value source, ArrayRef<OpFoldResult> position) {
1330 SmallVector<int64_t> staticPos;
1331 SmallVector<Value> dynamicPos;
1332 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1333 build(builder, result, source, dynamicPos,
1334 builder.getDenseI64ArrayAttr(staticPos));
1335}
1336
1337LogicalResult
1338ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1339 ExtractOp::Adaptor adaptor,
1340 SmallVectorImpl<Type> &inferredReturnTypes) {
1341 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1342 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1343 vectorType.getRank()) {
1344 inferredReturnTypes.push_back(vectorType.getElementType());
1345 } else {
1346 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1347 vectorType.getRank());
1348 inferredReturnTypes.push_back(VectorType::get(
1349 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1350 vectorType.getScalableDims().drop_front(n)));
1351 }
1352 return success();
1353}
1354
1355bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1356 // Allow extracting 1-element vectors instead of scalars.
1357 auto isCompatible = [](TypeRange l, TypeRange r) {
1358 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1359 return vectorType && vectorType.getShape().equals({1}) &&
1360 vectorType.getElementType() == r.front();
1361 };
1362 if (l.size() == 1 && r.size() == 1 &&
1363 (isCompatible(l, r) || isCompatible(r, l)))
1364 return true;
1365 return l == r;
1366}
1367
1368LogicalResult vector::ExtractOp::verify() {
1369 if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1370 if (resTy.getRank() == 0)
1371 return emitError(
1372 "expected a scalar instead of a 0-d vector as the result type");
1373
1374 // Note: This check must come before getMixedPosition() to prevent a crash.
1375 auto dynamicMarkersCount =
1376 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1377 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1378 return emitOpError(
1379 "mismatch between dynamic and static positions (kDynamic marker but no "
1380 "corresponding dynamic position) -- this can only happen due to an "
1381 "incorrect fold/rewrite");
1382 auto position = getMixedPosition();
1383 if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1384 return emitOpError(
1385 "expected position attribute of rank no greater than vector rank");
1386 for (auto [idx, pos] : llvm::enumerate(position)) {
1387 if (auto attr = dyn_cast<Attribute>(pos)) {
1388 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1389 if (!isValidPositiveIndexOrPoison(
1390 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1391 return emitOpError("expected position attribute #")
1392 << (idx + 1)
1393 << " to be a non-negative integer smaller than the "
1394 "corresponding vector dimension or poison (-1)";
1395 }
1396 }
1397 }
1398 return success();
1399}
1400
1401template <typename IntType>
1402static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1403 return llvm::to_vector<4>(llvm::map_range(
1404 arrayAttr.getAsRange<IntegerAttr>(),
1405 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1406}
1407
1408/// Fold the result of chains of ExtractOp in place by simply concatenating the
1409/// positions.
1410static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1411 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1412 return failure();
1413
1414 // TODO: Canonicalization for dynamic position not implemented yet.
1415 if (extractOp.hasDynamicPosition())
1416 return failure();
1417
1418 SmallVector<int64_t> globalPosition;
1419 ExtractOp currentOp = extractOp;
1420 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1421 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1422 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1423 currentOp = nextOp;
1424 // TODO: Canonicalization for dynamic position not implemented yet.
1425 if (currentOp.hasDynamicPosition())
1426 return failure();
1427 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1428 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1429 }
1430 extractOp.setOperand(0, currentOp.getSource());
1431 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1432 OpBuilder b(extractOp.getContext());
1433 std::reverse(globalPosition.begin(), globalPosition.end());
1434 extractOp.setStaticPosition(globalPosition);
1435 return success();
1436}
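// Illustrative fold (editorial example, not from the upstream file):
//   %a = vector.extract %src[1] : vector<3x4xf32> from vector<2x3x4xf32>
//   %b = vector.extract %a[2] : vector<4xf32> from vector<3x4xf32>
// is rewritten in place so that %b becomes
//   vector.extract %src[1, 2] : vector<4xf32> from vector<2x3x4xf32>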
1437
1438namespace {
1439/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1440/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1441/// Compose TransposeOp permutations as we walk back.
1442/// This helper class keeps an updated extraction position `extractPosition`
1443/// with extra trailing sentinels.
1444/// The sentinels encode the internal transposition status of the result vector.
1445/// As we iterate, extractPosition is permuted and updated.
1446class ExtractFromInsertTransposeChainState {
1447public:
1448 ExtractFromInsertTransposeChainState(ExtractOp e);
1449
1450 /// Iterate over producing insert and transpose ops until we find a fold.
1451 Value fold();
1452
1453private:
1454 /// Return true if the vector at position `a` is contained within the vector
1455 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1456 /// is a prefix of `b`.
1457 template <typename ContainerA, typename ContainerB>
1458 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1459 return a.size() <= b.size() &&
1460 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1461 }
1462
1463 /// Return true if the vector at position `a` intersects the vector at
1464 /// position `b`. Under insert/extract semantics, this is the same as equality
1465 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1466 /// Comparison is on the common prefix (i.e. zip).
1467 template <typename ContainerA, typename ContainerB>
1468 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1469 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1470 if (elemA < 0 || elemB < 0)
1471 continue;
1472 if (elemA != elemB)
1473 return false;
1474 }
1475 return true;
1476 }
1477
1478 /// Folding is only possible in the absence of an internal permutation in the
1479 /// result vector.
1480 bool canFold() {
1481 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1482 }
1483
1484 // Helper to get the next defining op of interest.
1485 void updateStateForNextIteration(Value v) {
1486 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1487 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1488 };
1489
1490 // Case 1. If we hit a transpose, just compose the map and iterate.
1491 // Invariant: insert + transpose do not change rank, we can always compose.
1492 LogicalResult handleTransposeOp();
1493
1494 // Case 2: the insert position matches extractPosition exactly, early return.
1495 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1496
1497 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1498 /// portion of the source of the insert.
1499 /// Example:
1500 /// ```
1501 /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1502 /// // extractPosition == [1, 2, 3]
1503 /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1504 /// // can fold to vector.extract %source[0, 3]
1505 /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1506 /// ```
1507 /// To traverse through %source, we need to set the leading dims to 0 and
1508 /// drop the extra leading dims.
1509 /// This method updates the internal state.
1510 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1511
1512 /// Try to fold in place to extract(source, extractPosition) and return the
1513 /// folded result. Return null if folding is not possible (e.g. due to an
1514 /// internal transposition in the result).
1515 Value tryToFoldExtractOpInPlace(Value source);
1516
1517 ExtractOp extractOp;
1518 int64_t vectorRank;
1519 int64_t extractedRank;
1520
1521 InsertOp nextInsertOp;
1522 TransposeOp nextTransposeOp;
1523
1524 /// Sentinel values that encode the internal permutation status of the result.
1525 /// They are set to (-1, ... , -k) at the beginning and appended to
1526 /// `extractPosition`.
1527 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1528 /// ensure that there is no internal transposition.
1529 /// Internal transposition cannot be accounted for with a folding pattern.
1530 // TODO: We could relax the internal transposition with an extra transposition
1531 // operation in a future canonicalizer.
1532 SmallVector<int64_t> sentinels;
1533 SmallVector<int64_t> extractPosition;
1534};
1535} // namespace
1536
1537ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1538 ExtractOp e)
1539 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1540 extractedRank(extractOp.getNumIndices()) {
1541 assert(vectorRank >= extractedRank && "Extracted position overflow");
1542 sentinels.reserve(vectorRank - extractedRank);
1543 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1544 sentinels.push_back(-(i + 1));
1545 extractPosition.assign(extractOp.getStaticPosition().begin(),
1546 extractOp.getStaticPosition().end());
1547 llvm::append_range(extractPosition, sentinels);
1548}
1549
1550// Case 1. If we hit a transpose, just compose the map and iterate.
1551// Invariant: insert + transpose do not change rank, we can always compose.
1552LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1553 // TODO: Canonicalization for dynamic position not implemented yet.
1554 if (extractOp.hasDynamicPosition())
1555 return failure();
1556
1557 if (!nextTransposeOp)
1558 return failure();
1559  AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1560      nextTransposeOp.getPermutation(), extractOp.getContext()));
1561  extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1562 return success();
1563}
1564
1565// Case 2: the insert position matches extractPosition exactly, early return.
1566LogicalResult
1567ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1568 Value &res) {
1569 // TODO: Canonicalization for dynamic position not implemented yet.
1570 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1571 return failure();
1572
1573 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1574 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1575 return failure();
1576 // Case 2.a. early-exit fold.
1577 res = nextInsertOp.getValueToStore();
1578 // Case 2.b. if internal transposition is present, canFold will be false.
1579 return success(canFold());
1580}
1581
1582/// Case 3: if inserted position is a prefix of extractPosition,
1583/// extract a portion of the source of the insertion.
1584/// This method updates the internal state.
1585LogicalResult
1586ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1587 // TODO: Canonicalization for dynamic position not implemented yet.
1588 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1589 return failure();
1590
1591 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1592 if (!isContainedWithin(insertedPos, extractPosition))
1593 return failure();
1594 // Set leading dims to zero.
1595 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1596 // Drop extra leading dims.
1597 extractPosition.erase(extractPosition.begin(),
1598 extractPosition.begin() + insertedPos.size());
1599 extractedRank = extractPosition.size() - sentinels.size();
1600 // Case 3.a. early-exit fold (break and delegate to post-while path).
1601 res = nextInsertOp.getValueToStore();
1602 // Case 3.b. if internal transposition is present, canFold will be false.
1603 return success();
1604}
1605
1606/// Try to fold in place to extract(source, extractPosition) and return the
1607/// folded result. Return null if folding is not possible (e.g. due to an
1608/// internal transposition in the result).
1609Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1610 Value source) {
1611 // TODO: Canonicalization for dynamic position not implemented yet.
1612 if (extractOp.hasDynamicPosition())
1613 return Value();
1614
1615 // If we can't fold (either internal transposition, or nothing to fold), bail.
1616 bool nothingToFold = (source == extractOp.getSource());
1617 if (nothingToFold || !canFold())
1618 return Value();
1619
1620 // Otherwise, fold by updating the op inplace and return its result.
1621 OpBuilder b(extractOp.getContext());
1622 extractOp.setStaticPosition(
1623 ArrayRef(extractPosition).take_front(extractedRank));
1624 extractOp.getSourceMutable().assign(source);
1625 return extractOp.getResult();
1626}
1627
1628/// Iterate over producing insert and transpose ops until we find a fold.
1629Value ExtractFromInsertTransposeChainState::fold() {
1630 // TODO: Canonicalization for dynamic position not implemented yet.
1631 if (extractOp.hasDynamicPosition())
1632 return Value();
1633
1634 Value valueToExtractFrom = extractOp.getSource();
1635 updateStateForNextIteration(valueToExtractFrom);
1636 while (nextInsertOp || nextTransposeOp) {
1637 // Case 1. If we hit a transpose, just compose the map and iterate.
1638 // Invariant: insert + transpose do not change rank, we can always compose.
1639 if (succeeded(handleTransposeOp())) {
1640 valueToExtractFrom = nextTransposeOp.getVector();
1641 updateStateForNextIteration(valueToExtractFrom);
1642 continue;
1643 }
1644
1645 Value result;
1646 // Case 2: the position match exactly.
1647 if (succeeded(handleInsertOpWithMatchingPos(result)))
1648 return result;
1649
1650 // Case 3: if the inserted position is a prefix of extractPosition, we can
1651 // just extract a portion of the source of the insert.
1652 if (succeeded(handleInsertOpWithPrefixPos(result)))
1653 return tryToFoldExtractOpInPlace(result);
1654
1655 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1656 // values. This is a more difficult case and we bail.
1657 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1658 if (isContainedWithin(extractPosition, insertedPos) ||
1659 intersectsWhereNonNegative(extractPosition, insertedPos))
1660 return Value();
1661
1662 // Case 5: No intersection, we forward the extract to insertOp.dest().
1663 valueToExtractFrom = nextInsertOp.getDest();
1664 updateStateForNextIteration(valueToExtractFrom);
1665 }
1666 // If after all this we can fold, go for it.
1667 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1668}
1669
1670/// Returns true if the operation has a 0-D vector type operand or result.
1671static bool hasZeroDimVectors(Operation *op) {
1672  auto hasZeroDimVectorType = [](Type type) -> bool {
1673 auto vecType = dyn_cast<VectorType>(type);
1674 return vecType && vecType.getRank() == 0;
1675 };
1676
1677 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1678 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1679}
1680
1681/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1682/// considered to be 'broadcastlike'.
1683static bool isBroadcastLike(Operation *op) {
1684 if (isa<BroadcastOp>(op))
1685 return true;
1686
1687 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1688 if (!shapeCast)
1689 return false;
1690
1691 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1692 // Checking that the destination shape has a prefix of 1s is not sufficient,
1693 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1694 // is that the source shape is a suffix of the destination shape.
1695 VectorType srcType = shapeCast.getSourceVectorType();
1696 ArrayRef<int64_t> srcShape = srcType.getShape();
1697 uint64_t srcRank = srcType.getRank();
1698 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1699 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1700}
1701
1702/// Fold extract(broadcast(X)) to either extract(X) or just X.
1703///
1704/// Example:
1705///
1706/// broadcast extract [1][2]
1707/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1708///
1709/// becomes
1710/// extract [1]
1711/// (3,4) -------------------------------------> (4)
1712///
1713///
1714/// The variable names used in this implementation correspond to the above
1715/// shapes as,
1716///
1717/// - (3, 4) is `input` shape.
1718/// - (2, 3, 4) is `broadcast` shape.
1719/// - (4) is `extract` shape.
1720///
1721/// This folding is possible when the suffix of `input` shape is the same as
1722/// `extract` shape.
1723static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1724
1725 Operation *defOp = extractOp.getSource().getDefiningOp();
1726 if (!defOp || !isBroadcastLike(defOp))
1727 return Value();
1728
1729 Value input = defOp->getOperand(0);
1730
1731 // Replace extract(broadcast(X)) with X
1732 if (extractOp.getType() == input.getType())
1733 return input;
1734
1735 // Get required types and ranks in the chain
1736 // input -> broadcast -> extract
1737 // (scalars are treated as rank-0).
1738 auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1739 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1740 unsigned inputRank = inputType ? inputType.getRank() : 0;
1741 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1742 unsigned extractRank = extractType ? extractType.getRank() : 0;
1743
1744 // Cannot do without the broadcast if overall the rank increases.
1745 if (extractRank > inputRank)
1746 return Value();
1747
1748 // The above condition guarantees that input is a vector.
1749 assert(inputType && "input must be a vector type because of previous checks");
1750 ArrayRef<int64_t> inputShape = inputType.getShape();
1751
1752 // In the case where there is a broadcast dimension in the suffix, it is not
1753 // possible to replace extract(broadcast(X)) with extract(X). Example:
1754 //
1755 // broadcast extract
1756 // (1) --------> (3,4) ------> (4)
1757 if (extractType &&
1758 extractType.getShape() != inputShape.take_back(extractRank))
1759 return Value();
1760
1761 // Replace extract(broadcast(X)) with extract(X).
1762 // First, determine the new extraction position.
1763 unsigned deltaOverall = inputRank - extractRank;
1764 unsigned deltaBroadcast = broadcastRank - inputRank;
1765 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1766 SmallVector<OpFoldResult> newPositions(deltaOverall);
1767 IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1768 for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1769 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1770 }
1771 auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1772 extractOp->setOperands(
1773 llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1774 extractOp.setStaticPosition(staticPos);
1775 return extractOp.getResult();
1776}
1777
1778/// Fold extractOp coming from ShuffleOp.
1779///
1780/// Example:
1781///
1782/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1783/// : vector<8xf32>, vector<8xf32>
1784/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1785/// ->
1786/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1787///
1788static Value foldExtractFromShuffle(ExtractOp extractOp) {
1789 // Dynamic positions are not folded as the resulting code would be more
1790 // complex than the input code.
1791 if (extractOp.hasDynamicPosition())
1792 return Value();
1793
1794 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1795 if (!shuffleOp)
1796 return Value();
1797
1798 // TODO: 0-D or multi-dimensional vectors not supported yet.
1799 if (shuffleOp.getResultVectorType().getRank() != 1)
1800 return Value();
1801
1802 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1803 auto shuffleMask = shuffleOp.getMask();
1804 int64_t extractIdx = extractOp.getStaticPosition()[0];
1805 int64_t shuffleIdx = shuffleMask[extractIdx];
1806
1807 // Find the shuffled vector to extract from based on the shuffle index.
1808 if (shuffleIdx < inputVecSize) {
1809 extractOp.setOperand(0, shuffleOp.getV1());
1810 extractOp.setStaticPosition({shuffleIdx});
1811 } else {
1812 extractOp.setOperand(0, shuffleOp.getV2());
1813 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1814 }
1815
1816 return extractOp.getResult();
1817}
1818
1819// Fold extractOp with source coming from ShapeCast op.
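//
// A hypothetical example (illustrative IR only):
//
//   %sc = vector.shape_cast %src : vector<2x4xf32> to vector<8xf32>
//   %e  = vector.extract %sc[5] : f32 from vector<8xf32>
//
// The position is linearized in the shape_cast result and re-expressed in the
// source shape, folding to `vector.extract %src[1, 1] : f32 from vector<2x4xf32>`.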
1820static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1821 // TODO: Canonicalization for dynamic position not implemented yet.
1822 if (extractOp.hasDynamicPosition())
1823 return Value();
1824
1825 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1826 if (!shapeCastOp)
1827 return Value();
1828
1829 // Get the nth dimension size starting from lowest dimension.
1830 auto getDimReverse = [](VectorType type, int64_t n) {
1831 return type.getShape().take_back(n + 1).front();
1832 };
1833 int64_t destinationRank =
1834 llvm::isa<VectorType>(extractOp.getType())
1835 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1836 : 0;
1837 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1838 return Value();
1839 if (destinationRank > 0) {
1840 auto destinationType =
1841 llvm::cast<VectorType>(extractOp.getResult().getType());
1842 for (int64_t i = 0; i < destinationRank; i++) {
1843 // The lowest dimension of the destination must match the lowest
1844 // dimension of the shapecast op source.
1845      // TODO: This case could be supported in a canonicalization pattern.
1846 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1847 getDimReverse(destinationType, i))
1848 return Value();
1849 }
1850 }
1851 // Extract the strides associated with the extract op vector source. Then use
1852 // this to calculate a linearized position for the extract.
1853 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1854 std::reverse(extractedPos.begin(), extractedPos.end());
1855  SmallVector<int64_t, 4> strides;
1856  int64_t stride = 1;
1857 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1858 strides.push_back(stride);
1859 stride *=
1860 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1861 }
1862
1863 int64_t position = linearize(extractedPos, strides);
1864 // Then extract the strides associated to the shapeCast op vector source and
1865 // delinearize the position using those strides.
1866 SmallVector<int64_t, 4> newStrides;
1867 int64_t numDimension =
1868 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1869 stride = 1;
1870 for (int64_t i = 0; i < numDimension; i++) {
1871 newStrides.push_back(stride);
1872 stride *=
1873 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1874 }
1875 std::reverse(newStrides.begin(), newStrides.end());
1876 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1877 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1878 OpBuilder b(extractOp.getContext());
1879 extractOp.setStaticPosition(newPosition);
1880 extractOp.setOperand(0, shapeCastOp.getSource());
1881 return extractOp.getResult();
1882}
1883
1884/// Fold an ExtractOp from ExtractStridedSliceOp.
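///
/// A hypothetical example (illustrative IR only):
///
///   %s = vector.extract_strided_slice %src
///        {offsets = [2], sizes = [4], strides = [1]}
///        : vector<8x16xf32> to vector<4x16xf32>
///   %e = vector.extract %s[1] : vector<16xf32> from vector<4x16xf32>
///
/// folds to `vector.extract %src[3] : vector<16xf32> from vector<8x16xf32>`
/// by adding the slice offsets to the extraction position.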
1885static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1886 // TODO: Canonicalization for dynamic position not implemented yet.
1887 if (extractOp.hasDynamicPosition())
1888 return Value();
1889
1890 auto extractStridedSliceOp =
1891 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1892 if (!extractStridedSliceOp)
1893 return Value();
1894
1895 // 0-D vectors not supported.
1896 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1897 if (hasZeroDimVectors(extractStridedSliceOp))
1898 return Value();
1899
1900 // Return if 'extractStridedSliceOp' has non-unit strides.
1901 if (extractStridedSliceOp.hasNonUnitStrides())
1902 return Value();
1903
1904 // Trim offsets for dimensions fully extracted.
1905 auto sliceOffsets =
1906 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1907 while (!sliceOffsets.empty()) {
1908 size_t lastOffset = sliceOffsets.size() - 1;
1909 if (sliceOffsets.back() != 0 ||
1910 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1911 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1912 break;
1913 sliceOffsets.pop_back();
1914 }
1915 unsigned destinationRank = 0;
1916 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1917 destinationRank = vecType.getRank();
1918 // The dimensions of the result need to be untouched by the
1919 // extractStridedSlice op.
1920 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1921 sliceOffsets.size())
1922 return Value();
1923
1924 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1925 assert(extractedPos.size() >= sliceOffsets.size());
1926 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1927 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1928 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1929
1930 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1931 OpBuilder b(extractOp.getContext());
1932 extractOp.setStaticPosition(extractedPos);
1933 return extractOp.getResult();
1934}
1935
1936/// Fold extract_op fed from a chain of insertStridedSlice ops.
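///
/// A hypothetical example (illustrative IR only):
///
///   %i = vector.insert_strided_slice %small, %big
///        {offsets = [4, 0], strides = [1, 1]}
///        : vector<2x16xf32> into vector<8x16xf32>
///   %e = vector.extract %i[5] : vector<16xf32> from vector<8x16xf32>
///
/// Row 5 lies entirely inside the inserted chunk (rows 4..5), so this folds to
/// `vector.extract %small[1] : vector<16xf32> from vector<2x16xf32>`. A row
/// outside the chunk would instead forward the extract to %big.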
1937static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1938 // TODO: Canonicalization for dynamic position not implemented yet.
1939 if (extractOp.hasDynamicPosition())
1940 return Value();
1941
1942 int64_t destinationRank =
1943 llvm::isa<VectorType>(extractOp.getType())
1944 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1945 : 0;
1946 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1947 if (!insertOp)
1948 return Value();
1949
1950 // 0-D vectors not supported.
1951 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1952 if (hasZeroDimVectors(insertOp))
1953 return Value();
1954
1955 while (insertOp) {
1956 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1957 insertOp.getSourceVectorType().getRank();
1958 if (destinationRank > insertOp.getSourceVectorType().getRank())
1959 return Value();
1960 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1961 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1962
1963 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1964 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1965 }))
1966 return Value();
1967 bool disjoint = false;
1968 SmallVector<int64_t, 4> offsetDiffs;
1969 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1970 int64_t start = insertOffsets[dim];
1971 int64_t size =
1972 (dim < insertRankDiff)
1973 ? 1
1974 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1975 int64_t end = start + size;
1976 int64_t offset = extractOffsets[dim];
1977 // Check if the start of the extract offset is in the interval inserted.
1978 if (start <= offset && offset < end) {
1979 if (dim >= insertRankDiff)
1980 offsetDiffs.push_back(offset - start);
1981 continue;
1982 }
1983 disjoint = true;
1984 break;
1985 }
1986    // The extracted element chunk overlaps with the inserted vector.
1987 if (!disjoint) {
1988 // If any of the inner dimensions are only partially inserted we have a
1989 // partial overlap.
1990 int64_t srcRankDiff =
1991 insertOp.getSourceVectorType().getRank() - destinationRank;
1992 for (int64_t i = 0; i < destinationRank; i++) {
1993 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1994 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1995 insertRankDiff))
1996 return Value();
1997 }
1998 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
1999 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2000 OpBuilder b(extractOp.getContext());
2001 extractOp.setStaticPosition(offsetDiffs);
2002 return extractOp.getResult();
2003 }
2004 // If the chunk extracted is disjoint from the chunk inserted, keep
2005 // looking in the insert chain.
2006 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2007 }
2008 return Value();
2009}
2010
2011/// Try to fold the extraction of a scalar from a vector defined by
2012/// vector.from_elements. E.g.:
2013///
2014/// %0 = vector.from_elements %a, %b : vector<2xf32>
2015/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2016/// ==> fold to %a
2017static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2018 // Dynamic extractions cannot be folded.
2019 if (extractOp.hasDynamicPosition())
2020 return {};
2021
2022 // Look for extract(from_elements).
2023 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2024 if (!fromElementsOp)
2025 return {};
2026
2027 // Scalable vectors are not supported.
2028 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2029 if (vecType.isScalable())
2030 return {};
2031
2032 // Only extractions of scalars are supported.
2033 int64_t rank = vecType.getRank();
2034 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2035 if (extractOp.getType() != vecType.getElementType())
2036 return {};
2037 assert(static_cast<int64_t>(indices.size()) == rank &&
2038 "unexpected number of indices");
2039
2040 // Compute flattened/linearized index and fold to operand.
2041 int flatIndex = 0;
2042 int stride = 1;
2043 for (int i = rank - 1; i >= 0; --i) {
2044 flatIndex += indices[i] * stride;
2045 stride *= vecType.getDimSize(i);
2046 }
2047 return fromElementsOp.getElements()[flatIndex];
2048}
2049
2050/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2051/// then fold it.
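/// A hypothetical example (illustrative IR only):
///
///   %c1 = arith.constant 1 : index
///   %e  = vector.extract %v[%c1, %i] : f32 from vector<4x8xf32>
///
/// becomes `vector.extract %v[1, %i] : f32 from vector<4x8xf32>`: the constant
/// index moves into the static position while %i remains dynamic.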
2052template <typename OpType, typename AdaptorType>
2053static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2054 SmallVectorImpl<Value> &operands) {
2055 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2056 OperandRange dynamicPosition = op.getDynamicPosition();
2057 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2058  ArrayRef<int64_t> vectorShape;
2059  if constexpr (std::is_same_v<OpType, ExtractOp>)
2060 vectorShape = op.getSourceVectorType().getShape();
2061 else
2062 vectorShape = op.getDestVectorType().getShape();
2063
2064  // If there are no dynamic position operands, there is nothing to fold.
2065 if (!dynamicPosition.size())
2066 return {};
2067
2068 // `index` is used to iterate over the `dynamicPosition`.
2069 unsigned index = 0;
2070
2071 // `opChange` is a flag. If it is true, it means to update `op` in place.
2072 bool opChange = false;
2073 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2074 if (ShapedType::isStatic(staticPosition[i]))
2075 continue;
2076 Attribute positionAttr = dynamicPositionAttr[index];
2077 Value position = dynamicPosition[index++];
2078 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2079 int64_t value = attr.getInt();
2080 // Do not fold if the value is out of bounds (-1 signifies a poison
2081 // value rather than OOB index).
2082 if (value >= -1 && value < vectorShape[i]) {
2083 staticPosition[i] = attr.getInt();
2084 opChange = true;
2085 continue;
2086 }
2087 }
2088 operands.push_back(position);
2089 }
2090
2091 if (opChange) {
2092 op.setStaticPosition(staticPosition);
2093 op.getOperation()->setOperands(operands);
2094 // Return the original result to indicate an in-place folding happened.
2095 return op.getResult();
2096 }
2097 return {};
2098}
2099
2100/// Fold an insert or extract operation into a poison value when a poison index
2101/// is found at any dimension of the static position.
2102static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2103                                                ArrayRef<int64_t> staticPos,
2104 int64_t poisonVal) {
2105 if (!is_contained(staticPos, poisonVal))
2106 return {};
2107
2108 return ub::PoisonAttr::get(context);
2109}
2110
2111/// Fold a vector extract whose source is a poison value.
2112static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2113  if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2114 return srcAttr;
2115
2116 return {};
2117}
2118
2119/// Fold a vector extract extracting from a DenseElementsAttr.
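/// A hypothetical example (illustrative IR only):
///
///   %cst = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : vector<2x2xf32>
///   %e   = vector.extract %cst[1] : vector<2xf32> from vector<2x2xf32>
///
/// folds to the constant `dense<[2.0, 3.0]> : vector<2xf32>`.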
2120static OpFoldResult foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2121                                                      Attribute srcAttr) {
2122 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2123 if (!denseAttr) {
2124 return {};
2125 }
2126
2127 if (denseAttr.isSplat()) {
2128 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2129 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2130 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2131 return newAttr;
2132 }
2133
2134 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2135 if (vecTy.isScalable())
2136 return {};
2137
2138 if (extractOp.hasDynamicPosition()) {
2139 return {};
2140 }
2141
2142 // Materializing subsets of a large constant array can generally lead to
2143 // explosion in IR size because of different combination of subsets that
2144 // can exist. However, vector.extract is a restricted form of subset
2145 // extract where you can only extract non-overlapping (or the same) subset for
2146 // a given rank of the subset. Because of this property, the IR size can only
2147 // increase at most by `rank * size(array)` from a single constant array being
2148 // extracted by multiple extracts.
2149
2150 // Calculate the linearized position of the continuous chunk of elements to
2151 // extract.
2152 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2153 copy(extractOp.getStaticPosition(), completePositions.begin());
2154 int64_t startPos =
2155 linearize(completePositions, computeStrides(vecTy.getShape()));
2156 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2157
2158 TypedAttr newAttr;
2159 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2160 SmallVector<Attribute> elementValues(
2161 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2162 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2163 } else {
2164 newAttr = *denseValuesBegin;
2165 }
2166
2167 return newAttr;
2168}
2169
2170OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2171 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2172 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2173 // mismatch).
2174 if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2175 return getSource();
2176 if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2177 return res;
2178 // Fold `arith.constant` indices into the `vector.extract` operation.
2179 // Do not stop here as this fold may enable subsequent folds that require
2180 // constant indices.
2181 SmallVector<Value> operands = {getSource()};
2182 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2183
2184 if (auto res = foldPoisonIndexInsertExtractOp(
2185 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2186 return res;
2187 if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2188 return res;
2189 if (succeeded(foldExtractOpFromExtractChain(*this)))
2190 return getResult();
2191 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2192 return res;
2193 if (auto res = foldExtractFromBroadcast(*this))
2194 return res;
2195 if (auto res = foldExtractFromShuffle(*this))
2196 return res;
2197 if (auto res = foldExtractFromShapeCast(*this))
2198 return res;
2199 if (auto val = foldExtractFromExtractStrided(*this))
2200 return val;
2201 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2202 return val;
2203 if (auto val = foldScalarExtractFromFromElements(*this))
2204 return val;
2205
2206 return inplaceFolded;
2207}
2208
2209namespace {
2210
2211// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
2212class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2213public:
2214 using Base::Base;
2215
2216 LogicalResult matchAndRewrite(ExtractOp extractOp,
2217 PatternRewriter &rewriter) const override {
2218
2219 Operation *defOp = extractOp.getSource().getDefiningOp();
2220 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2221 if (!defOp || !isBroadcastLike(defOp) || !outType)
2222 return failure();
2223
2224 Value source = defOp->getOperand(0);
2225 if (isBroadcastableTo(source.getType(), outType) !=
2226 BroadcastableToResult::Success)
2227 return failure();
2228
2229 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2230 return success();
2231 }
2232};
2233
2234// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
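//
// A hypothetical example (illustrative IR only):
//
//   %c2 = arith.constant 2 : index
//   %m  = vector.create_mask %c2, %d : vector<4x8xi1>
//   %e  = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
//
// Position 1 is below the constant bound 2, so this rewrites to
// `vector.create_mask %d : vector<8xi1>`. With position 3 the extracted mask
// would be all-false and rewrites to a zero (all-false) constant instead.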
2235class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2236public:
2237 using Base::Base;
2238
2239 LogicalResult matchAndRewrite(ExtractOp extractOp,
2240 PatternRewriter &rewriter) const override {
2241 auto createMaskOp =
2242 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2243 if (!createMaskOp)
2244 return failure();
2245
2246 VectorType extractedMaskType =
2247 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2248
2249 if (!extractedMaskType)
2250 return failure();
2251
2252 auto maskOperands = createMaskOp.getOperands();
2253 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2254 VectorType maskType = createMaskOp.getVectorType();
2255
2256 bool containsUnknownDims = false;
2257 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2258
2259 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2260 dimIdx++) {
2261 int64_t pos = extractOpPos[dimIdx];
2262 Value operand = maskOperands[dimIdx];
2263 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2264 if (!constantOp) {
2265 // Bounds of this dim unknown.
2266 containsUnknownDims = true;
2267 continue;
2268 }
2269
2270 int64_t createMaskBound =
2271 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2272
2273 if (pos != ShapedType::kDynamic) {
2274 // If any position is outside the range from the `create_mask`, then the
2275 // extracted mask will be all-false.
2276 allFalse |= pos >= createMaskBound;
2277 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2278 // This dim is not all-true and since this is a dynamic index we don't
2279 // know if the extraction is within the true or false region.
2280        // Note: Zero dims have already been handled via getMaskFormat().
2281 containsUnknownDims = true;
2282 }
2283 }
2284
2285 if (allFalse) {
2286 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2287 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2288 } else if (!containsUnknownDims) {
2289 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2290 extractOp, extractedMaskType,
2291 maskOperands.drop_front(extractOpPos.size()));
2292 } else {
2293 return failure();
2294 }
2295 return success();
2296 }
2297};
2298
2299// Folds extract(shape_cast(..)) into shape_cast when the total element count
2300// does not change.
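//
// A hypothetical example (illustrative IR only):
//
//   %sc = vector.shape_cast %src : vector<8xf32> to vector<1x2x4xf32>
//   %e  = vector.extract %sc[0] : vector<2x4xf32> from vector<1x2x4xf32>
//
// %e holds all 8 elements of %src, so this rewrites to
// `vector.shape_cast %src : vector<8xf32> to vector<2x4xf32>`.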
2301LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2302 PatternRewriter &rewriter) {
2303 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2304 if (!castOp)
2305 return failure();
2306
2307 VectorType sourceType = castOp.getSourceVectorType();
2308 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2309 if (!targetType)
2310 return failure();
2311
2312 if (sourceType.getNumElements() != targetType.getNumElements())
2313 return failure();
2314
2315 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2316 castOp.getSource());
2317 return success();
2318}
2319
2320/// Try to canonicalize the extraction of a subvector from a vector defined by
2321/// vector.from_elements. E.g.:
2322///
2323/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2324/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2325/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2326LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2327 PatternRewriter &rewriter) {
2328 // Dynamic positions are not supported.
2329 if (extractOp.hasDynamicPosition())
2330 return failure();
2331
2332 // Scalar extracts are handled by the folder.
2333 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2334 if (!resultType)
2335 return failure();
2336
2337 // Look for extracts from a from_elements op.
2338 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2339 if (!fromElementsOp)
2340 return failure();
2341 VectorType inputType = fromElementsOp.getType();
2342
2343 // Scalable vectors are not supported.
2344 if (resultType.isScalable() || inputType.isScalable())
2345 return failure();
2346
2347 // Compute the position of first extracted element and flatten/linearize the
2348 // position.
2349 SmallVector<int64_t> firstElementPos =
2350 llvm::to_vector(extractOp.getStaticPosition());
2351 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2352 int flatIndex = 0;
2353 int stride = 1;
2354 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2355 flatIndex += firstElementPos[i] * stride;
2356 stride *= inputType.getDimSize(i);
2357 }
2358
2359 // Replace the op with a smaller from_elements op.
2360 rewriter.replaceOpWithNewOp<FromElementsOp>(
2361 extractOp, resultType,
2362 fromElementsOp.getElements().slice(flatIndex,
2363 resultType.getNumElements()));
2364 return success();
2365}
2366
2367} // namespace
2368
2369void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2370 MLIRContext *context) {
2371 results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2372 results.add(foldExtractFromShapeCastToShapeCast);
2373 results.add(foldExtractFromFromElements);
2374}
2375
2376static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2377                                       SmallVectorImpl<int64_t> &results) {
2378 for (auto attr : arrayAttr)
2379 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2380}
2381
2382//===----------------------------------------------------------------------===//
2383// FmaOp
2384//===----------------------------------------------------------------------===//
2385
2386std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2387 return llvm::to_vector<4>(getVectorType().getShape());
2388}
2389
2390//===----------------------------------------------------------------------===//
2391// ToElementsOp
2392//===----------------------------------------------------------------------===//
2393
2394/// Returns true if all the `operands` are defined by `defOp`.
2395/// Otherwise, returns false.
2396static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2397 if (operands.empty())
2398 return false;
2399
2400 return llvm::all_of(operands, [&](Value operand) {
2401 Operation *currentDef = operand.getDefiningOp();
2402 return currentDef == defOp;
2403 });
2404}
2405
2406/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2407/// (%e0, %e1, ...). For example:
2408///
2409/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2410/// %1:3 = vector.to_elements %0 : vector<3xf32>
2411/// user_op %1#0, %1#1, %1#2
2412///
2413/// becomes:
2414///
2415/// user_op %a, %b, %c
2416///
2417static LogicalResult
2418foldToElementsFromElements(ToElementsOp toElementsOp,
2419                           SmallVectorImpl<OpFoldResult> &results) {
2420  auto fromElementsOp =
2421 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2422 if (!fromElementsOp)
2423 return failure();
2424
2425 llvm::append_range(results, fromElementsOp.getElements());
2426 return success();
2427}
2428
2429/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2430///
2431/// Example:
2432/// %b = vector.broadcast %x : i32 to vector<3xf32>
2433/// %e:3 = vector.to_elements %b : vector<3xf32>
2434/// user_op %e#0, %e#1, %e#2
2435/// becomes:
2436/// user_op %x, %x, %x
2437///
2438/// The vector source case is handled by a canonicalization pattern.
2439static LogicalResult
2440foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2441                          SmallVectorImpl<OpFoldResult> &results) {
2442  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2443 if (!bcastOp)
2444 return failure();
2445 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2446 if (isa<VectorType>(bcastOp.getSource().getType()))
2447 return failure();
2448
2449 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2450
2451 Value scalar = bcastOp.getSource();
2452 results.assign(resultVecType.getNumElements(), scalar);
2453 return success();
2454}
2455
2456LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2457 SmallVectorImpl<OpFoldResult> &results) {
2458 if (succeeded(foldToElementsFromElements(*this, results)))
2459 return success();
2460 return foldToElementsOfBroadcast(*this, results);
2461}
2462
2463LogicalResult
2464ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2465 ToElementsOp::Adaptor adaptor,
2466 SmallVectorImpl<Type> &inferredReturnTypes) {
2467 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2468 Type elType = vecType.getElementType();
2469 inferredReturnTypes.append(vecType.getNumElements(), elType);
2470 return success();
2471}
2472
2473/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2474/// vector.
2475/// - Build `vector.to_elements %v` and remap each destination element to the
2476/// corresponding source element using broadcast rules (match or 1 →
2477/// replicate).
2478///
2479/// Example:
2480/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2481/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2482/// becomes:
2483/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2484/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2485/// // %src_elems#1, %src_elems#0, %src_elems#1
2486struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2487 using Base::Base;
2488
2489 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2490 PatternRewriter &rewriter) const override {
2491 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2492 if (!bcastOp)
2493 return failure();
2494
2495 // Only handle broadcasts from a vector source here.
2496 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2497 if (!srcType)
2498 return failure();
2499
2500 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2501
2502 ArrayRef<int64_t> dstShape = dstType.getShape();
2503 ArrayRef<int64_t> srcShape = srcType.getShape();
2504
2505 int64_t dstRank = dstShape.size();
2506 int64_t srcRank = srcShape.size();
2507
2508 // Create elements for the broadcast source vector.
2509 auto srcElems = vector::ToElementsOp::create(
2510 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2511
2512 int64_t dstCount = llvm::product_of(dstShape);
2513
2514 SmallVector<Value> replacements;
2515 replacements.reserve(dstCount);
2516
2517 // For each element of the destination, determine which element of the
2518 // source should be used. We walk all destination positions using a single
2519 // counter, decode it into per-dimension indices, then build the matching
2520 // source position: use the same index where sizes match, and use 0 where
2521 // the source size is 1 (replication). This mapping is needed so we can
2522 // replace each result of to_elements with the corresponding element from
2523 // the broadcast source.
2524 // Inner-dimension stretch example:
2525 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2526 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2527 // becomes:
2528 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2529 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2530 // // %src_elems#1, %src_elems#0, %src_elems#1,
2531 // // %src_elems#2, %src_elems#3, %src_elems#2,
2532 // // %src_elems#3, %src_elems#2, %src_elems#3
2533
2534 // Row-major strides for the destination shape.
2535 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2536 // Row-major strides for the source shape.
2537 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2538 SmallVector<int64_t> dstIdx(dstRank);
2539 SmallVector<int64_t> srcIdx(srcRank);
2540 for (int64_t lin = 0; lin < dstCount; ++lin) {
2541 // Convert linear destination index to per-dimension indices.
2542 dstIdx = delinearize(lin, dstStrides);
2543 for (int64_t k = 0; k < srcRank; ++k)
2544 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2545 // Convert per-dimension source indices back to a linear index.
2546 int64_t srcLin = linearize(srcIdx, srcStrides);
2547 replacements.push_back(srcElems.getResult(srcLin));
2548 }
2549
2550 rewriter.replaceOp(toElementsOp, replacements);
2551 return success();
2552 }
2553};
2554
2555void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2556 MLIRContext *context) {
2557 results.add<ToElementsOfBroadcast>(context);
2558}
2559
2560//===----------------------------------------------------------------------===//
2561// FromElementsOp
2562//===----------------------------------------------------------------------===//
2563
2564/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2565///
2566/// Case #1: Input and output vectors are the same.
2567///
2568/// %0:3 = vector.to_elements %a : vector<3xf32>
2569/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2570/// user_op %1
2571///
2572/// becomes:
2573///
2574/// user_op %a
2575///
2576static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2577 OperandRange fromElemsOperands = fromElementsOp.getElements();
2578 if (fromElemsOperands.empty())
2579 return {};
2580
2581 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2582 if (!toElementsOp)
2583 return {};
2584
2585 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2586 return {};
2587
2588 // Case #1: Input and output vectors are the same. Forward the input vector.
2589 Value toElementsInput = toElementsOp.getSource();
2590 if (fromElementsOp.getType() == toElementsInput.getType() &&
2591 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2592 return toElementsInput;
2593 }
2594
2595 // TODO: Support cases with different input and output shapes and different
2596 // number of elements.
2597
2598 return {};
2599}
2600
2601/// Fold vector.from_elements to a constant when all operands are constants.
2602/// Example:
2603/// %c1 = arith.constant 1 : i32
2604/// %c2 = arith.constant 2 : i32
2605/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2606/// =>
2607/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2608///
2609static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2610 ArrayRef<Attribute> elements) {
2611 // Check for null or poison attributes before any processing.
2612 if (llvm::any_of(elements, [](Attribute attr) {
2613 return !attr || isa<ub::PoisonAttrInterface>(attr);
2614 }))
2615 return {};
2616
2617 // DenseElementsAttr only supports int/index/float/complex types.
2618 auto destVecType = fromElementsOp.getDest().getType();
2619 auto destEltType = destVecType.getElementType();
2620 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2621 return {};
2622
2623 // Constant attributes might have a different type than the return type.
2624 // Convert them before creating the dense elements attribute.
2625 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2626 return convertNumericAttr(attr, destEltType);
2627 });
2628
2629 return DenseElementsAttr::get(destVecType, convertedElements);
2630}
2631
2632OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2633 if (auto res = foldFromElementsToElements(*this))
2634 return res;
2635 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2636 return res;
2637
2638 return {};
2639}
2640
2641/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2642/// same. Example:
2643/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2644/// =>
2645/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2646static LogicalResult
2647rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2648 PatternRewriter &rewriter) {
2649 if (!llvm::all_equal(fromElementsOp.getElements()))
2650 return failure();
2651 rewriter.replaceOpWithNewOp<BroadcastOp>(
2652 fromElementsOp, fromElementsOp.getType(),
2653 fromElementsOp.getElements().front());
2654 return success();
2655}
2656
2657/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2658/// on a single extract. Example:
2659/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2660/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2661/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2662///
2663/// becomes
2664/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2665/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2666///
2667/// The requirements for this to be valid are
2668///
2669/// i) The elements are extracted from the same vector (%source).
2670///
2671/// ii) The elements form a suffix of %source. Specifically, the number
2672/// of elements is the same as the product of the last N dimension sizes
2673/// of %source, for some N.
2674///
2675/// iii) The elements are extracted contiguously in ascending order.
2676
2677class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2678
2679 using Base::Base;
2680
2681 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2682 PatternRewriter &rewriter) const override {
2683
2684 // Handled by `rewriteFromElementsAsBroadcast`.
2685 if (fromElements.getType().getNumElements() == 1)
2686 return failure();
2687
2688 // The common source that all elements are extracted from, if one exists.
2689    Value source;
2690    // The position of the combined extract operation, if one is created.
2691 ArrayRef<int64_t> combinedPosition;
2692 // The expected index of extraction of the current element in the loop, if
2693 // elements are extracted contiguously in ascending order.
2694 SmallVector<int64_t> expectedPosition;
2695
2696 for (auto [insertIndex, element] :
2697 llvm::enumerate(fromElements.getElements())) {
2698
2699 // Check that the element is from a vector.extract operation.
2700 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2701 if (!extractOp) {
2702 return rewriter.notifyMatchFailure(fromElements,
2703 "element not from vector.extract");
2704 }
2705
2706 // Check condition (i) by checking that all elements have the same source
2707 // as the first element.
2708 if (insertIndex == 0) {
2709 source = extractOp.getSource();
2710 } else if (extractOp.getSource() != source) {
2711 return rewriter.notifyMatchFailure(fromElements,
2712 "element from different vector");
2713 }
2714
2715 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2716 int64_t rank = position.size();
2717 assert(rank == source.getType().getRank() &&
2718 "scalar extract must have full rank position");
2719
2720 // Check condition (ii) by checking that the position that the first
2721 // element is extracted from has sufficient trailing 0s. For example, in
2722 //
2723 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2724 // [...]
2725 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2726 //
2727 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2728      // elements, which is the number of elements of %elms, so this is valid.
2729 if (insertIndex == 0) {
2730 const int64_t numElms = fromElements.getType().getNumElements();
2731 int64_t numSuffixElms = 1;
2732 int64_t index = rank;
2733 while (index > 0 && position[index - 1] == 0 &&
2734 numSuffixElms < numElms) {
2735 numSuffixElms *= source.getType().getDimSize(index - 1);
2736 --index;
2737 }
2738 if (numSuffixElms != numElms) {
2739 return rewriter.notifyMatchFailure(
2740 fromElements, "elements do not form a suffix of source");
2741 }
2742 expectedPosition = llvm::to_vector(position);
2743 combinedPosition = position.drop_back(rank - index);
2744 }
2745
2746 // Check condition (iii).
2747 else if (expectedPosition != position) {
2748 return rewriter.notifyMatchFailure(
2749 fromElements, "elements not in ascending order (static order)");
2750 }
2751 increment(expectedPosition, source.getType().getShape());
2752 }
2753
2754 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2755 fromElements.getLoc(), source, combinedPosition);
2756
2757 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2758 fromElements, fromElements.getType(), extracted);
2759
2760 return success();
2761 }
2762
2763 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2764 static void increment(MutableArrayRef<int64_t> indices,
2765                        ArrayRef<int64_t> shape) {
2766    for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2767 indices[dim] += 1;
2768 if (indices[dim] < shape[dim])
2769 break;
2770 indices[dim] = 0;
2771 }
2772 }
2773};
2774
2775void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2776 MLIRContext *context) {
2777  results.add(rewriteFromElementsAsBroadcast);
2778  results.add<FromElementsToShapeCast>(context);
2779}
2780
2781//===----------------------------------------------------------------------===//
2782// BroadcastOp
2783//===----------------------------------------------------------------------===//
2784
2785void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2786 SetIntRangeFn setResultRanges) {
2787 setResultRanges(getResult(), argRanges.front());
2788}
2789
2790std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2791 return llvm::to_vector<4>(getResultVectorType().getShape());
2792}
2793
2794/// Return the dimensions of the result vector that were formerly ones in the
2795/// source vector and thus correspond to "dim-1" broadcasting.
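/// For example (hypothetical shapes, for illustration only), broadcasting
/// vector<3x1x4xf32> to vector<2x3x5x4xf32> returns {2}: dimension 2 of the
/// result (size 5) is produced by stretching a unit dimension of the source.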
2796static llvm::SetVector<int64_t>
2797computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2798                           ArrayRef<int64_t> dstShape) {
2799 int64_t rankDiff = dstShape.size() - srcShape.size();
2800 int64_t dstDim = rankDiff;
2801  llvm::SetVector<int64_t> res;
2802  for (auto [s1, s2] :
2803 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2804 if (s1 != s2) {
2805 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2806 res.insert(dstDim);
2807 }
2808 ++dstDim;
2809 }
2810 return res;
2811}
2812
2813llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2814  // A broadcast from a scalar does not involve any unit-dim broadcasting.
2815 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2816 if (!srcVectorType)
2817 return {};
2818 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2819 getResultVectorType().getShape());
2820}
2821
2822/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2823/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2824/// This requires (and asserts) that the broadcast is free of "dim-1"
2825/// broadcasting.
2826/// Since vector.broadcast only allows expanding leading dimensions, an extra
2827/// vector.transpose may be inserted to make the broadcast possible.
2828/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2829/// the helper will assert. This means:
2830/// 1. `dstShape` must not be empty.
2831/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2832///   3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2833///      must match the `value` shape.
2834Value BroadcastOp::createOrFoldBroadcastOp(
2835 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2836 const llvm::SetVector<int64_t> &broadcastedDims) {
2837 assert(!dstShape.empty() && "unexpected empty dst shape");
2838
2839 // Well-formedness check.
2840 SmallVector<int64_t> checkShape;
2841 for (int i = 0, e = dstShape.size(); i < e; ++i) {
2842 if (broadcastedDims.contains(i))
2843 continue;
2844 checkShape.push_back(dstShape[i]);
2845 }
2846 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2847 "ill-formed broadcastedDims contains values not confined to "
2848 "destVectorShape");
2849
2850 Location loc = value.getLoc();
2851 Type elementType = getElementTypeOrSelf(value.getType());
2852 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2853 VectorType dstVectorType = VectorType::get(dstShape, elementType);
2854
2855 // Step 2. If scalar -> dstShape broadcast, just do it.
2856 if (!srcVectorType) {
2857 assert(checkShape.empty() &&
2858 "ill-formed createOrFoldBroadcastOp arguments");
2859 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2860 }
2861
2862 assert(srcVectorType.getShape().equals(checkShape) &&
2863 "ill-formed createOrFoldBroadcastOp arguments");
2864
2865 // Step 3. Since vector.broadcast only allows creating leading dims,
2866 // vector -> dstShape broadcast may require a transpose.
2867 // Traverse the dims in order and construct:
2868 // 1. The leading entries of the broadcastShape that is guaranteed to be
2869 // achievable by a simple broadcast.
2870 // 2. The induced permutation for the subsequent vector.transpose that will
2871  //      bring us from `broadcastShape` back to the desired `dstShape`.
2872 // If the induced permutation is not the identity, create a vector.transpose.
2873 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2874 broadcastShape.reserve(dstShape.size());
2875 // Consider the example:
2876 // srcShape = 2x4
2877 // dstShape = 1x2x3x4x5
2878 // broadcastedDims = [0, 2, 4]
2879 //
2880 // We want to build:
2881 // broadcastShape = 1x3x5x2x4
2882 // permutation = [0, 2, 4, 1, 3]
2883 // ---V--- -----V-----
2884 // leading broadcast part src shape part
2885 //
2886 // Note that the trailing dims of broadcastShape are exactly the srcShape
2887 // by construction.
2888 // nextSrcShapeDim is used to keep track of where in the permutation the
2889 // "src shape part" occurs.
2890 int64_t nextSrcShapeDim = broadcastedDims.size();
2891 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2892 if (broadcastedDims.contains(i)) {
2893 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2894 // bring it to the head of the broadcastShape.
2895 // It will need to be permuted back from `broadcastShape.size() - 1` into
2896 // position `i`.
2897 broadcastShape.push_back(dstShape[i]);
2898 permutation[i] = broadcastShape.size() - 1;
2899 } else {
2900 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2901 // shape and needs to be permuted into position `i`.
2902 // Don't touch `broadcastShape` here, the whole srcShape will be
2903 // appended after.
2904 permutation[i] = nextSrcShapeDim++;
2905 }
2906 }
2907 // 3.c. Append the srcShape.
2908 llvm::append_range(broadcastShape, srcVectorType.getShape());
2909
2910 // Ensure there are no "dim-1" broadcasts.
2911 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2912 .empty() &&
2913 "unexpected \"dim-1\" broadcast");
2914
2915 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2916 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2917 vector::BroadcastableToResult::Success &&
2918 "must be broadcastable");
2919 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2920 // Step 4. If we find any dimension that indeed needs to be permuted,
2921 // immediately return a new vector.transpose.
2922 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2923 if (permutation[i] != i)
2924 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2925 // Otherwise return res.
2926 return res;
2927}
2928
2929BroadcastableToResult mlir::vector::isBroadcastableTo(
2930    Type srcType, VectorType dstVectorType,
2931 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2932 // Broadcast scalar to vector of the same element type.
2933 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2934 srcType == getElementTypeOrSelf(dstVectorType))
2935    return BroadcastableToResult::Success;
2936  // From now on, only vectors broadcast.
2937 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2938 if (!srcVectorType)
2939    return BroadcastableToResult::SourceTypeNotAVector;
2940
2941 int64_t srcRank = srcVectorType.getRank();
2942 int64_t dstRank = dstVectorType.getRank();
2943 if (srcRank > dstRank)
2944    return BroadcastableToResult::SourceRankHigher;
2945  // Source has an exact match or singleton value for all trailing dimensions
2946 // (all leading dimensions are simply duplicated).
2947 int64_t lead = dstRank - srcRank;
2948 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2949 // Have mismatching dims (in the sense of vector.broadcast semantics) been
2950 // encountered?
2951 bool foundMismatchingDims = false;
2952
2953 // Check fixed-width dims.
2954 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2955 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2956 if (srcDim != 1 && srcDim != dstDim)
2957 foundMismatchingDims = true;
2958
2959 // Check scalable flags.
2960 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2961 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2962 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2963 // 1 -> [N] is fine, everything else should be rejected when mixing
2964 // fixed-width and scalable dims
2965 (srcDimScalableFlag != dstDimScalableFlag &&
2966 (srcDim != 1 || srcDimScalableFlag)))
2967 foundMismatchingDims = true;
2968
2969 if (foundMismatchingDims) {
2970 if (mismatchingDims != nullptr) {
2971 mismatchingDims->first.dim = srcDim;
2972 mismatchingDims->first.isScalable = srcDimScalableFlag;
2973
2974 mismatchingDims->second.dim = dstDim;
2975 mismatchingDims->second.isScalable = dstDimScalableFlag;
2976 }
2977 return BroadcastableToResult::DimensionMismatch;
2978 }
2979 }
2980
2981 return BroadcastableToResult::Success;
2982}
2983
2984LogicalResult BroadcastOp::verify() {
2985 std::pair<VectorDim, VectorDim> mismatchingDims;
2986 BroadcastableToResult res = isBroadcastableTo(
2987 getSourceType(), getResultVectorType(), &mismatchingDims);
2988 if (res == BroadcastableToResult::Success)
2989 return success();
2990 if (res == BroadcastableToResult::SourceRankHigher)
2991 return emitOpError("source rank higher than destination rank");
2992 if (res == BroadcastableToResult::DimensionMismatch) {
2993 return emitOpError("dimension mismatch (")
2994 << (mismatchingDims.first.isScalable ? "[" : "")
2995 << mismatchingDims.first.dim
2996 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2997 << (mismatchingDims.second.isScalable ? "[" : "")
2998 << mismatchingDims.second.dim
2999 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
3000 }
3001 if (res == BroadcastableToResult::SourceTypeNotAVector)
3002 return emitOpError("source type is not a vector");
3003 llvm_unreachable("unexpected vector.broadcast op error");
3004}
3005
3006// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3007// with broadcast's result type and shape_cast only adds or removes ones in the
3008// leading dimensions.
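// A minimal illustration of this fold (hypothetical IR; names, f32 element
// type and shapes chosen only for exposition):
//
//   %0 = vector.shape_cast %x : vector<3x2xf32> to vector<1x3x2xf32>
//   %1 = vector.broadcast %0 : vector<1x3x2xf32> to vector<4x3x2xf32>
//
// can be rewritten as
//
//   %1 = vector.broadcast %x : vector<3x2xf32> to vector<4x3x2xf32>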
3009static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3010 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3011 if (!srcShapeCast)
3012 return failure();
3013
3014 VectorType srcType = srcShapeCast.getSourceVectorType();
3015 VectorType destType = broadcastOp.getResultVectorType();
3016 // Check type compatibility.
3017 if (vector::isBroadcastableTo(srcType, destType) !=
3018 vector::BroadcastableToResult::Success)
3019 return failure();
3020
3021 ArrayRef<int64_t> srcShape = srcType.getShape();
3022 ArrayRef<int64_t> shapecastShape =
3023 srcShapeCast.getResultVectorType().getShape();
3024 // Trailing dimensions should be the same if shape_cast only alters the
3025 // leading dimensions.
3026 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3027 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3028 shapecastShape.take_back(numTrailingDims)))
3029 return failure();
3030
3031 assert(all_of(srcShape.drop_back(numTrailingDims),
3032 [](int64_t E) { return E == 1; }) &&
3033 all_of(shapecastShape.drop_back(numTrailingDims),
3034 [](int64_t E) { return E == 1; }) &&
3035 "ill-formed shape_cast");
3036
3037 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3038 return success();
3039}
3040
3041OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3042 if (getSourceType() == getResultVectorType())
3043 return getSource();
3044 if (succeeded(foldBroadcastOfShapeCast(*this)))
3045 return getResult();
3046
3047 if (!adaptor.getSource())
3048 return {};
3049 auto vectorType = getResultVectorType();
3050 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3051 if (vectorType.getElementType() != attr.getType())
3052 return {};
3053 return DenseElementsAttr::get(vectorType, attr);
3054 }
3055 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3056 if (vectorType.getElementType() != attr.getType())
3057 return {};
3058 return DenseElementsAttr::get(vectorType, attr);
3059 }
3060 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3061 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3062 if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
3063 return ub::PoisonAttr::get(getContext());
3064 return {};
3065}
3066
3067namespace {
3068
3069// Fold broadcast1(broadcast2(x)) into broadcast1(x).
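// For example (illustrative shapes only):
//   %0 = vector.broadcast %x : vector<4xf32> to vector<2x4xf32>
//   %1 = vector.broadcast %0 : vector<2x4xf32> to vector<3x2x4xf32>
// becomes
//   %1 = vector.broadcast %x : vector<4xf32> to vector<3x2x4xf32>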
3070struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3071 using Base::Base;
3072
3073 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3074 PatternRewriter &rewriter) const override {
3075 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3076 if (!srcBroadcast)
3077 return failure();
3078 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3079 broadcastOp.getResultVectorType(),
3080 srcBroadcast.getSource());
3081 return success();
3082 }
3083};
3084} // namespace
3085
3086void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3087 MLIRContext *context) {
3088 // BroadcastToShapeCast is not a default canonicalization; it is opt-in via
3089 // `populateCastAwayVectorLeadingOneDimPatterns`.
3090 results.add<BroadcastFolder>(context);
3091}
3092
3093//===----------------------------------------------------------------------===//
3094// ShuffleOp
3095//===----------------------------------------------------------------------===//
3096
3097LogicalResult ShuffleOp::verify() {
3098 VectorType resultType = getResultVectorType();
3099 VectorType v1Type = getV1VectorType();
3100 VectorType v2Type = getV2VectorType();
3101 // Verify ranks.
3102 int64_t resRank = resultType.getRank();
3103 int64_t v1Rank = v1Type.getRank();
3104 int64_t v2Rank = v2Type.getRank();
3105 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3106 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3107 if (!wellFormed0DCase && !wellFormedNDCase)
3108 return emitOpError("rank mismatch");
3109
3110 // Verify all but leading dimension sizes.
3111 for (int64_t r = 1; r < v1Rank; ++r) {
3112 int64_t resDim = resultType.getDimSize(r);
3113 int64_t v1Dim = v1Type.getDimSize(r);
3114 int64_t v2Dim = v2Type.getDimSize(r);
3115 if (resDim != v1Dim || v1Dim != v2Dim)
3116 return emitOpError("dimension mismatch");
3117 }
3118 // Verify mask length.
3119 ArrayRef<int64_t> mask = getMask();
3120 int64_t maskLength = mask.size();
3121 if (maskLength <= 0)
3122 return emitOpError("invalid mask length");
3123 if (maskLength != resultType.getDimSize(0))
3124 return emitOpError("mask length mismatch");
3125 // Verify all indices.
3126 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3127 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3128 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3129 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3130 return emitOpError("mask index #") << (idx + 1) << " out of range";
3131 }
3132 return success();
3133}
3134
3135LogicalResult
3136ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
3137 ShuffleOp::Adaptor adaptor,
3138 SmallVectorImpl<Type> &inferredReturnTypes) {
3139 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
3140 auto v1Rank = v1Type.getRank();
3141 // Construct resulting type: leading dimension matches mask
3142 // length, all trailing dimensions match the operands.
3143 SmallVector<int64_t, 4> shape;
3144 shape.reserve(v1Rank);
3145 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3146 // In the 0-D case there is no trailing shape to append.
3147 if (v1Rank > 0)
3148 llvm::append_range(shape, v1Type.getShape().drop_front());
3149 inferredReturnTypes.push_back(
3150 VectorType::get(shape, v1Type.getElementType()));
3151 return success();
3152}
3153
3154template <typename T>
3155static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3156 T expected = begin;
3157 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3158 return value == expected++;
3159 });
3160}
3161
3162OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3163 auto v1Type = getV1VectorType();
3164 auto v2Type = getV2VectorType();
3165
3166 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3167 "Vector shuffle does not support scalable vectors");
3168
3169 // For consistency: the 0-D shuffle return type is 1-D, so this cannot be a
3170 // folding but must be a canonicalization into a vector.broadcast.
3171 if (v1Type.getRank() == 0)
3172 return {};
3173
3174 // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3175 auto mask = getMask();
3176 if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3177 return getV1();
3178 // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3179 if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3180 return getV2();
3181
3182 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3183 if (!v1Attr || !v2Attr)
3184 return {};
3185
3186 // Fold shuffle poison, poison -> poison.
3187 bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
3188 bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
3189 if (isV1Poison && isV2Poison)
3190 return ub::PoisonAttr::get(getContext());
3191
3192 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
3193 // manipulation.
3194 if (v1Type.getRank() != 1)
3195 return {};
3196
3197 // Poison input attributes need special handling as they are not
3198 // DenseElementsAttr. If an index is poison, we select the first element of
3199 // the first non-poison input.
3200 SmallVector<Attribute> v1Elements, v2Elements;
3201 Attribute poisonElement;
3202 if (!isV2Poison) {
3203 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3204 if (!v2DenseAttr)
3205 return {};
3206 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3207 poisonElement = v2Elements[0];
3208 }
3209 if (!isV1Poison) {
3210 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3211 if (!v1DenseAttr)
3212 return {};
3213 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3214 poisonElement = v1Elements[0];
3215 }
3216
3217 SmallVector<Attribute> results;
3218 int64_t v1Size = v1Type.getDimSize(0);
3219 for (int64_t maskIdx : mask) {
3220 Attribute indexedElm;
3221 // TODO: Return a partial poison vector when supported by the UB dialect.
3222 if (maskIdx == ShuffleOp::kPoisonIndex) {
3223 indexedElm = poisonElement;
3224 } else {
3225 if (maskIdx < v1Size)
3226 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3227 else
3228 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3229 }
3230
3231 results.push_back(indexedElm);
3232 }
3233
3234 return DenseElementsAttr::get(getResultVectorType(), results);
3235}
3236
3237namespace {
3238
3239// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3240// to a broadcast.
3241struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3242 using Base::Base;
3243
3244 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3245 PatternRewriter &rewriter) const override {
3246 VectorType v1VectorType = shuffleOp.getV1VectorType();
3247 ArrayRef<int64_t> mask = shuffleOp.getMask();
3248 if (v1VectorType.getRank() > 0)
3249 return failure();
3250 if (mask.size() != 1)
3251 return failure();
3252 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3253 if (mask[0] == 0)
3254 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3255 shuffleOp.getV1());
3256 else
3257 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3258 shuffleOp.getV2());
3259 return success();
3260 }
3261};
3262
3263/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3264/// vector.broadcast with a scalar operand, return the scalar value that is
3265/// splatted. Otherwise return null.
3266///
3267/// Example:
3268///
3269/// scalar_source --> vector.broadcast --> value - return scalar_source
3270static Value getScalarSplatSource(Value value) {
3271 // Block argument:
3272 Operation *defOp = value.getDefiningOp();
3273 if (!defOp)
3274 return {};
3275
3276 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3277
3278 // Not broadcast (and not splat):
3279 if (!broadcast)
3280 return {};
3281
3282 // Broadcast of a vector:
3283 if (isa<VectorType>(broadcast.getSourceType()))
3284 return {};
3285
3286 // Broadcast of a scalar:
3287 return broadcast.getSource();
3288}
3289
3290/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
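///
/// For example (illustrative IR; names and types are placeholders):
///   %s = vector.broadcast %v : f32 to vector<4xf32>
///   %r = vector.shuffle %s, %s [0, 3, 1] : vector<4xf32>, vector<4xf32>
/// is rewritten as
///   %r = vector.broadcast %v : f32 to vector<3xf32>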
3291class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3292public:
3293 using Base::Base;
3294
3295 LogicalResult matchAndRewrite(ShuffleOp op,
3296 PatternRewriter &rewriter) const override {
3297 Value splat = getScalarSplatSource(op.getV1());
3298 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3299 return failure();
3300
3301 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3302 return success();
3303 }
3304};
3305
3306/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3307/// vector.interleave.
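///
/// For example (a sketch; the exact vector.interleave syntax may differ):
///   %0 = vector.shuffle %a, %b [0, 2, 1, 3] : vector<2xf32>, vector<2xf32>
/// is rewritten as
///   %0 = vector.interleave %a, %b : vector<2xf32> -> vector<4xf32>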
3308class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3309public:
3310 using Base::Base;
3311
3312 LogicalResult matchAndRewrite(ShuffleOp op,
3313 PatternRewriter &rewriter) const override {
3314 VectorType resultType = op.getResultVectorType();
3315 if (resultType.isScalable())
3316 return rewriter.notifyMatchFailure(
3317 op, "ShuffleOp can't represent a scalable interleave");
3318
3319 if (resultType.getRank() != 1)
3320 return rewriter.notifyMatchFailure(
3321 op, "ShuffleOp can't represent an n-D interleave");
3322
3323 VectorType sourceType = op.getV1VectorType();
3324 if (sourceType != op.getV2VectorType() ||
3325 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3326 return rewriter.notifyMatchFailure(
3327 op, "ShuffleOp types don't match an interleave");
3328 }
3329
3330 ArrayRef<int64_t> shuffleMask = op.getMask();
3331 int64_t resultVectorSize = resultType.getNumElements();
3332 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3333 int64_t maskValueA = shuffleMask[i * 2];
3334 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3335 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3336 return rewriter.notifyMatchFailure(op,
3337 "ShuffleOp mask not interleaving");
3338 }
3339
3340 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3341 return success();
3342 }
3343};
3344
3345} // namespace
3346
3347void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3348 MLIRContext *context) {
3349 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3350 context);
3351}
3352
3353//===----------------------------------------------------------------------===//
3354// InsertOp
3355//===----------------------------------------------------------------------===//
3356
3357void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3358 SetIntRangeFn setResultRanges) {
3359 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3360}
3361
3362void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3363 Value source, Value dest) {
3364 auto vectorTy = cast<VectorType>(dest.getType());
3365 build(builder, result, source, dest,
3366 SmallVector<int64_t>(vectorTy.getRank(), 0));
3367}
3368
3369void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3370 Value source, Value dest, int64_t position) {
3371 build(builder, result, source, dest, ArrayRef<int64_t>{position});
3372}
3373
3374void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3375 Value source, Value dest, OpFoldResult position) {
3376 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3377}
3378
3379void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3380 Value source, Value dest,
3381 ArrayRef<int64_t> position) {
3382 SmallVector<OpFoldResult> posVals;
3383 posVals.reserve(position.size());
3384 llvm::transform(position, std::back_inserter(posVals),
3385 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3386 build(builder, result, source, dest, posVals);
3387}
3388
3389void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3390 Value source, Value dest,
3391 ArrayRef<OpFoldResult> position) {
3392 SmallVector<int64_t> staticPos;
3393 SmallVector<Value> dynamicPos;
3394 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3395 build(builder, result, source, dest, dynamicPos,
3396 builder.getDenseI64ArrayAttr(staticPos));
3397}
3398
3399LogicalResult InsertOp::verify() {
3400 if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3401 if (srcTy.getRank() == 0)
3402 return emitError(
3403 "expected a scalar instead of a 0-d vector as the source operand");
3404
3405 SmallVector<OpFoldResult> position = getMixedPosition();
3406 auto destVectorType = getDestVectorType();
3407 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3408 return emitOpError(
3409 "expected position attribute of rank no greater than dest vector rank");
3410 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3411 if (srcVectorType &&
3412 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3413 static_cast<unsigned>(destVectorType.getRank())))
3414 return emitOpError("expected position attribute rank + source rank to "
3415 "match dest vector rank");
3416 if (!srcVectorType &&
3417 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3418 return emitOpError(
3419 "expected position attribute rank to match the dest vector rank");
3420 for (auto [idx, pos] : llvm::enumerate(position)) {
3421 if (auto attr = dyn_cast<Attribute>(pos)) {
3422 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3423 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3424 destVectorType.getDimSize(idx))) {
3425 return emitOpError("expected position attribute #")
3426 << (idx + 1)
3427 << " to be a non-negative integer smaller than the "
3428 "corresponding "
3429 "dest vector dimension";
3430 }
3431 }
3432 }
3433 return success();
3434}
3435
3436// Calculate the linearized position of the contiguous chunk of elements to
3437// insert, based on the shape of the value to insert and the positions at which
3438// it is inserted.
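// For example (illustrative): destTy = vector<2x3x4xf32> has strides
// [12, 4, 1]; positions [1, 2] are padded to [1, 2, 0], which linearizes to
// 1*12 + 2*4 + 0*1 = 20.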
3439static int64_t calculateInsertPosition(VectorType destTy,
3440 ArrayRef<int64_t> positions) {
3441 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3442 assert(positions.size() <= completePositions.size() &&
3443 "positions size must be less than or equal to destTy rank");
3444 copy(positions, completePositions.begin());
3445 return linearize(completePositions, computeStrides(destTy.getShape()));
3446}
3447
3448namespace {
3449
3450// If insertOp is only inserting unit dimensions it can be transformed to a
3451// broadcast.
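// For example (illustrative shapes only):
//   %0 = vector.insert %v, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
// becomes
//   %0 = vector.broadcast %v : vector<4xf32> to vector<1x1x4xf32>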
3452class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3453public:
3454 using Base::Base;
3455
3456 LogicalResult matchAndRewrite(InsertOp insertOp,
3457 PatternRewriter &rewriter) const override {
3458 auto srcVecType =
3459 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3460 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3461 srcVecType.getNumElements())
3462 return failure();
3463 rewriter.replaceOpWithNewOp<BroadcastOp>(
3464 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3465 return success();
3466 }
3467};
3468
3469/// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
3470class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3471public:
3472 using Base::Base;
3473
3474 LogicalResult matchAndRewrite(InsertOp op,
3475 PatternRewriter &rewriter) const override {
3476
3477 Value splat = getScalarSplatSource(op.getValueToStore());
3478 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3479 return failure();
3480
3481 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3482 return success();
3483 }
3484};
3485
3486/// Pattern to optimize a chain of insertions.
3487///
3488/// This pattern identifies chains of vector.insert operations that:
3489/// 1. Only insert values at static positions.
3490/// 2. Completely initialize all elements in the resulting vector.
3491/// 3. Have intermediate insert operations with only one use each.
3492///
3493/// When these conditions are met, the entire chain can be replaced with a
3494/// single vector.from_elements operation.
3495///
3496/// To keep this pattern simple, and avoid spending too much time on matching
3497/// fragmented insert chains, this pattern only considers the last insert op in
3498/// the chain.
3499///
3500/// Example transformation:
3501/// %poison = ub.poison : vector<2xi32>
3502/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3503/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3504/// ->
3505/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3506class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3507public:
3508 using Base::Base;
3509 LogicalResult matchAndRewrite(InsertOp op,
3510 PatternRewriter &rewriter) const override {
3511
3512 VectorType destTy = op.getDestVectorType();
3513 if (destTy.isScalable())
3514 return failure();
3515 // Ensure this is the trailing vector.insert op in a chain of inserts.
3516 for (Operation *user : op.getResult().getUsers())
3517 if (auto insertOp = dyn_cast<InsertOp>(user))
3518 if (insertOp.getDest() == op.getResult())
3519 return failure();
3520
3521 InsertOp currentOp = op;
3522 SmallVector<InsertOp> chainInsertOps;
3523 while (currentOp) {
3524 // Check cond 1: Dynamic position is not supported.
3525 if (currentOp.hasDynamicPosition())
3526 return failure();
3527
3528 chainInsertOps.push_back(currentOp);
3529 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3530 // Check cond 3: Intermediate inserts have only one use to avoid an
3531 // explosion of vectors.
3532 if (currentOp && !currentOp->hasOneUse())
3533 return failure();
3534 }
3535
3536 int64_t vectorSize = destTy.getNumElements();
3537 int64_t initializedCount = 0;
3538 SmallVector<bool> initializedDestIdxs(vectorSize, false);
3539 SmallVector<int64_t> pendingInsertPos;
3540 SmallVector<int64_t> pendingInsertSize;
3541 SmallVector<Value> pendingInsertValues;
3542
3543 for (auto insertOp : chainInsertOps) {
3544 // This pattern can do nothing with poison index.
3545 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3546 return failure();
3547
3548 // Calculate the linearized position for inserting elements.
3549 int64_t insertBeginPosition =
3550 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3551
3552 // The valueToStore operand may be a vector or a scalar. Need to handle
3553 // both cases.
3554 int64_t insertSize = 1;
3555 if (auto srcVectorType =
3556 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3557 insertSize = srcVectorType.getNumElements();
3558
3559 assert(insertBeginPosition + insertSize <= vectorSize &&
3560 "insert would overflow the vector");
3561
3562 for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3563 insertBeginPosition + insertSize)) {
3564 if (initializedDestIdxs[index])
3565 continue;
3566 initializedDestIdxs[index] = true;
3567 ++initializedCount;
3568 }
3569
3570 // Defer creating any ops until we are sure the pattern will
3571 // succeed.
3572 pendingInsertPos.push_back(insertBeginPosition);
3573 pendingInsertSize.push_back(insertSize);
3574 pendingInsertValues.push_back(insertOp.getValueToStore());
3575
3576 if (initializedCount == vectorSize)
3577 break;
3578 }
3579
3580 // Check cond 2: all positions must be initialized.
3581 if (initializedCount != vectorSize)
3582 return failure();
3583
3584 SmallVector<Value> elements(vectorSize);
3585 for (auto [insertBeginPosition, insertSize, valueToStore] :
3586 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3587 pendingInsertValues))) {
3588 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3589
3590 if (!srcVectorType) {
3591 elements[insertBeginPosition] = valueToStore;
3592 continue;
3593 }
3594
3595 SmallVector<Type> elementToInsertTypes(insertSize,
3596 srcVectorType.getElementType());
3597 // Get all elements from the vector in row-major order.
3598 auto elementsToInsert = vector::ToElementsOp::create(
3599 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3600 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3601 elements[insertBeginPosition + linearIdx] =
3602 elementsToInsert.getResult(linearIdx);
3603 }
3604 }
3605
3606 rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3607 return success();
3608 }
3609};
3610
3611} // namespace
3612
3613static Attribute
3614foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3615 Attribute dstAttr,
3616 int64_t maxVectorSizeFoldThreshold) {
3617 if (insertOp.hasDynamicPosition())
3618 return {};
3619
3620 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3621 if (!denseDst)
3622 return {};
3623
3624 if (!srcAttr) {
3625 return {};
3626 }
3627
3628 VectorType destTy = insertOp.getDestVectorType();
3629 if (destTy.isScalable())
3630 return {};
3631
3632 // Make sure we do not create too many large constants.
3633 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3634 !insertOp->hasOneUse())
3635 return {};
3636
3637 // Calculate the linearized position for inserting elements.
3638 int64_t insertBeginPosition =
3639 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3640 SmallVector<Attribute> insertedValues;
3641 Type destEltType = destTy.getElementType();
3642
3643 /// Converts attribute to the expected type if there's
3644 /// a mismatch.
3645 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3646 for (auto value : denseSource.getValues<Attribute>())
3647 insertedValues.push_back(convertNumericAttr(value, destEltType));
3648 } else {
3649 insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3650 }
3651
3652 auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3653 copy(insertedValues, allValues.begin() + insertBeginPosition);
3654 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3655
3656 return newAttr;
3657}
3658
3659/// Folder to replace the `dest` operand of the insert op with the root dest of
3660/// the insert op use chain.
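///
/// For example (illustrative):
///   %0 = vector.insert %a, %dest [1] : f32 into vector<4xf32>
///   %1 = vector.insert %b, %0 [1] : f32 into vector<4xf32>
/// Both inserts write the same position, so %1 can take %dest directly as its
/// `dest` operand, leaving %0 dead if it has no other uses.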
3661static Value foldInsertUseChain(InsertOp insertOp) {
3662 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3663 if (!destInsert)
3664 return {};
3665
3666 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3667 return {};
3668
3669 insertOp.setOperand(1, destInsert.getDest());
3670 return insertOp.getResult();
3671}
3672
3673void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3674 MLIRContext *context) {
3675 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3676 InsertChainFullyInitialized>(context);
3677}
3678
3679OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3680 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3681 // unless the source vector constant has a single use.
3682 constexpr int64_t vectorSizeFoldThreshold = 256;
3683 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
3684 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3685 // (type mismatch).
3686 if (getNumIndices() == 0 && getValueToStoreType() == getType())
3687 return getValueToStore();
3688 // Fold `arith.constant` indices into the `vector.insert` operation.
3689 // Do not stop here as this fold may enable subsequent folds that require
3690 // constant indices.
3691 SmallVector<Value> operands = {getValueToStore(), getDest()};
3692 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3693
3694 if (auto res = foldInsertUseChain(*this))
3695 return res;
3696 if (auto res = foldPoisonIndexInsertExtractOp(
3697 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3698 return res;
3699 if (auto res = foldDenseElementsAttrDestInsertOp(
3700 *this, adaptor.getValueToStore(), adaptor.getDest(),
3701 vectorSizeFoldThreshold)) {
3702 return res;
3703 }
3704
3705 return inplaceFolded;
3706}
3707
3708//===----------------------------------------------------------------------===//
3709// InsertStridedSliceOp
3710//===----------------------------------------------------------------------===//
3711
3712void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3713 Value source, Value dest,
3714 ArrayRef<int64_t> offsets,
3715 ArrayRef<int64_t> strides) {
3716 result.addOperands({source, dest});
3717 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3718 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3719 result.addTypes(dest.getType());
3720 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3721 offsetsAttr);
3722 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3723 stridesAttr);
3724}
3725
3726// TODO: Should be moved to Tablegen ConfinedAttr attributes.
3727template <typename OpType>
3728static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3729 ArrayAttr arrayAttr,
3730 ArrayRef<int64_t> shape,
3731 StringRef attrName) {
3732 if (arrayAttr.size() > shape.size())
3733 return op.emitOpError("expected ")
3734 << attrName << " attribute of rank no greater than vector rank";
3735 return success();
3736}
3737
3738// Returns success if all integers in `arrayAttr` are confined to the
3739// admissible interval, which is [min, max) when `halfOpen` is true and
3740// [min, max] otherwise.
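// For example (illustrative): with min = 1, max = 1, and halfOpen = false (as
// used for unit strides below), [1, 1] verifies while [1, 2] is rejected.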
3741template <typename OpType>
3742static LogicalResult
3743isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3744 int64_t max, StringRef attrName,
3745 bool halfOpen = true) {
3746 for (auto attr : arrayAttr) {
3747 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3748 auto upper = max;
3749 if (!halfOpen)
3750 upper += 1;
3751 if (val < min || val >= upper)
3752 return op.emitOpError("expected ") << attrName << " to be confined to ["
3753 << min << ", " << upper << ")";
3754 }
3755 return success();
3756}
3757
3758// Returns success if each integer in `arrayAttr` is confined to the
3759// corresponding dimension of `shape`, i.e. [min, shape[i]) when `halfOpen` is
3760// true and [min, shape[i]] otherwise.
3761template <typename OpType>
3762static LogicalResult
3763isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3764 ArrayRef<int64_t> shape, StringRef attrName,
3765 bool halfOpen = true, int64_t min = 0) {
3766 for (auto [index, attrDimPair] :
3767 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3768 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3769 int64_t max = std::get<1>(attrDimPair);
3770 if (!halfOpen)
3771 max += 1;
3772 if (val < min || val >= max)
3773 return op.emitOpError("expected ")
3774 << attrName << " dimension " << index << " to be confined to ["
3775 << min << ", " << max << ")";
3776 }
3777 return success();
3778}
3779
3780// Returns success if, for each index i in [0, shape.size()), the sum
3781// val = `arrayAttr1[i]` + `arrayAttr2[i]`
3782// is confined to the admissible interval with max = shape[i]:
3783// [min, max) if `halfOpen` is true, or
3784// [min, max] otherwise.
3785template <typename OpType>
3786static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3787 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3788 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3789 bool halfOpen = true, int64_t min = 1) {
3790 assert(arrayAttr1.size() <= shape.size());
3791 assert(arrayAttr2.size() <= shape.size());
3792 for (auto [index, it] :
3793 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3794 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3795 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3796 int64_t max = std::get<2>(it);
3797 if (!halfOpen)
3798 max += 1;
3799 if (val1 + val2 < 0 || val1 + val2 >= max)
3800 return op.emitOpError("expected sum(")
3801 << attrName1 << ", " << attrName2 << ") dimension " << index
3802 << " to be confined to [" << min << ", " << max << ")";
3803 }
3804 return success();
3805}
3806
3807static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3808 MLIRContext *context) {
3809 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3810 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3811 });
3812 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3813}
3814
3815LogicalResult InsertStridedSliceOp::verify() {
3816 auto sourceVectorType = getSourceVectorType();
3817 auto destVectorType = getDestVectorType();
3818 auto offsets = getOffsetsAttr();
3819 auto strides = getStridesAttr();
3820 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3821 return emitOpError(
3822 "expected offsets of same size as destination vector rank");
3823 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3824 return emitOpError("expected strides of same size as source vector rank");
3825 if (sourceVectorType.getRank() > destVectorType.getRank())
3826 return emitOpError(
3827 "expected source rank to be no greater than destination rank");
3828
3829 auto sourceShape = sourceVectorType.getShape();
3830 auto destShape = destVectorType.getShape();
3831 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3832 destShape.size() - sourceShape.size(), 0);
3833 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3834 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3835 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3836 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3837 offName)) ||
3838 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3839 /*max=*/1, stridesName,
3840 /*halfOpen=*/false)) ||
3841 failed(isSumOfIntegerArrayAttrConfinedToShape(
3842 *this, offsets,
3843 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3844 offName, "source vector shape",
3845 /*halfOpen=*/false, /*min=*/1)))
3846 return failure();
3847
3848 unsigned rankDiff = destShape.size() - sourceShape.size();
3849 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3850 if (sourceVectorType.getScalableDims()[idx] !=
3851 destVectorType.getScalableDims()[idx + rankDiff]) {
3852 return emitOpError("mismatching scalable flags (at source vector idx=")
3853 << idx << ")";
3854 }
3855 if (sourceVectorType.getScalableDims()[idx]) {
3856 auto sourceSize = sourceShape[idx];
3857 auto destSize = destShape[idx + rankDiff];
3858 if (sourceSize != destSize) {
3859 return emitOpError("expected size at idx=")
3860 << idx
3861 << (" to match the corresponding base size from the input "
3862 "vector (")
3863 << sourceSize << (" vs ") << destSize << (")");
3864 }
3865 }
3866 }
3867
3868 return success();
3869}
3870
3871namespace {
3872/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
3873class FoldInsertStridedSliceSplat final
3874 : public OpRewritePattern<InsertStridedSliceOp> {
3875public:
3876 using Base::Base;
3877
3878 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3879 PatternRewriter &rewriter) const override {
3880
3881 auto dst = insertStridedSliceOp.getDest();
3882 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3883 if (!splat || getScalarSplatSource(dst) != splat)
3884 return failure();
3885
3886 rewriter.replaceOp(insertStridedSliceOp, dst);
3887 return success();
3888 }
3889};
3890
3891/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3892/// to dst.
3893class FoldInsertStridedSliceOfExtract final
3894 : public OpRewritePattern<InsertStridedSliceOp> {
3895public:
3896 using Base::Base;
3897
3898 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3899 PatternRewriter &rewriter) const override {
3900 auto extractStridedSliceOp =
3901 insertStridedSliceOp.getValueToStore()
3902 .getDefiningOp<vector::ExtractStridedSliceOp>();
3903
3904 if (!extractStridedSliceOp)
3905 return failure();
3906
3907 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3908 return failure();
3909
3910 // Check if have the same strides and offsets.
3911 if (extractStridedSliceOp.getStrides() !=
3912 insertStridedSliceOp.getStrides() ||
3913 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3914 return failure();
3915
3916 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3917 return success();
3918 }
3919};
3920
3921// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3922// ConstantOp.
3923class InsertStridedSliceConstantFolder final
3924 : public OpRewritePattern<InsertStridedSliceOp> {
3925public:
3926 using Base::Base;
3927
3928 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3929 // unless the source vector constant has a single use.
3930 static constexpr int64_t vectorSizeFoldThreshold = 256;
3931
3932 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3933 PatternRewriter &rewriter) const override {
3934 // Return if the 'InsertStridedSliceOp' dest operand is not defined by a
3935 // compatible vector ConstantOp.
3936 TypedValue<VectorType> destVector = op.getDest();
3937 Attribute vectorDestCst;
3938 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3939 return failure();
3940
3941 VectorType destTy = destVector.getType();
3942 if (destTy.isScalable())
3943 return failure();
3944
3945 // Make sure we do not create too many large constants.
3946 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3947 !destVector.hasOneUse())
3948 return failure();
3949
3950 TypedValue<VectorType> sourceValue = op.getValueToStore();
3951 Attribute sourceCst;
3952 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3953 return failure();
3954
3955 // TODO: Support poison.
3956 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3957 return failure();
3958
3959 // TODO: Handle non-unit strides when they become available.
3960 if (op.hasNonUnitStrides())
3961 return failure();
3962
3963 VectorType sliceVecTy = sourceValue.getType();
3964 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3965 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3966 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3967 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3968
3969 // Calculate the destination element indices by enumerating all slice
3970 // positions within the destination and linearizing them. The enumeration
3971 // order is lexicographic, which yields a sequence of monotonically
3972 // increasing linearized position indices.
3973 // Because the destination may have higher dimensionality than the slice,
3974 // we keep track of two overlapping sets of positions and offsets.
3975 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3976 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3977 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3978 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3979 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3980 MutableArrayRef<int64_t> currSlicePosition(
3981 currDestPosition.begin() + rankDifference, currDestPosition.end());
3982 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3983 offsets.end());
3984 do {
3985 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3986 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3987 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3988 "Invalid slice element");
3989 newValues[linearizedPosition] = *sliceValuesIt;
3990 ++sliceValuesIt;
3991 } while (succeeded(
3992 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3993
3994 auto newAttr = DenseElementsAttr::get(destTy, newValues);
3995 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3996 return success();
3997 }
3998};
3999
4000} // namespace
4001
4002void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4003 RewritePatternSet &results, MLIRContext *context) {
4004 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4005 InsertStridedSliceConstantFolder>(context);
4006}
4007
4008OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4009 if (getSourceVectorType() == getDestVectorType())
4010 return getValueToStore();
4011 return {};
4012}
4013
4014//===----------------------------------------------------------------------===//
4015// OuterProductOp
4016//===----------------------------------------------------------------------===//
4017
4018/// Build an op without mask, use the type of `acc` as the return type.
4019void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4020 Value lhs, Value rhs, Value acc) {
4021 result.addOperands({lhs, rhs, acc});
4022 result.addTypes(acc.getType());
4023}
4024
4025void OuterProductOp::print(OpAsmPrinter &p) {
4026 p << " " << getLhs() << ", " << getRhs();
4027 if (getAcc()) {
4028 p << ", " << getAcc();
4029 p.printOptionalAttrDict((*this)->getAttrs());
4030 }
4031 p << " : " << getLhs().getType() << ", " << getRhs().getType();
4032}
4033
4034ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4035 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4036 Type tLHS, tRHS;
4037 if (parser.parseOperandList(operandsInfo) ||
4038 parser.parseOptionalAttrDict(result.attributes) ||
4039 parser.parseColonType(tLHS) || parser.parseComma() ||
4040 parser.parseType(tRHS))
4041 return failure();
4042 if (operandsInfo.size() < 2)
4043 return parser.emitError(parser.getNameLoc(),
4044 "expected at least 2 operands");
4045 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4046 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4047 if (!vLHS)
4048 return parser.emitError(parser.getNameLoc(),
4049 "expected vector type for operand #1");
4050
4051 VectorType resType;
4052 if (vRHS) {
4053 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4054 vRHS.getScalableDims()[0]};
4055 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4056 vLHS.getElementType(), scalableDimsRes);
4057 } else {
4058 // Scalar RHS operand
4059 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4060 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4061 scalableDimsRes);
4062 }
4063
4064 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4065 result.attributes.append(
4066 OuterProductOp::getKindAttrName(result.name),
4067 CombiningKindAttr::get(result.getContext(),
4068 OuterProductOp::getDefaultKind()));
4069 }
4070
4071 return failure(
4072 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4073 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4074 (operandsInfo.size() > 2 &&
4075 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4076 parser.addTypeToList(resType, result.types));
4077}
4078
4079LogicalResult OuterProductOp::verify() {
4080 Type tRHS = getOperandTypeRHS();
4081 VectorType vLHS = getOperandVectorTypeLHS(),
4082 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4083 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4084
4085 if (vLHS.getRank() != 1)
4086 return emitOpError("expected 1-d vector for operand #1");
4087
4088 if (vRHS) {
4089 // Proper OUTER operation.
4090 if (vRHS.getRank() != 1)
4091 return emitOpError("expected 1-d vector for operand #2");
4092 if (vRES.getRank() != 2)
4093 return emitOpError("expected 2-d vector result");
4094 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4095 return emitOpError("expected #1 operand dim to match result dim #1");
4096 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4097 return emitOpError("expected #2 operand dim to match result dim #2");
4098 if (vLHS.isScalable() && !vRHS.isScalable()) {
4099 // This restriction reflects what's currently supported in terms of
4100 // scalable vectors. However, we could relax this if there's a use case.
4101 return emitOpError(
4102 "expected either both or only #2 operand dim to be scalable");
4103 }
4104 } else {
4105 // An AXPY operation.
4106 if (vRES.getRank() != 1)
4107 return emitOpError("expected 1-d vector result");
4108 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4109 return emitOpError("expected #1 operand dim to match result dim #1");
4110 }
4111
4112 if (vACC && vACC != vRES)
4113 return emitOpError("expected operand #3 of same type as result type");
4114
4115 if (!getKindAttr()) {
4116 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
4117 "'vector.kind<add>')");
4118 }
4119
4120 // Verify supported combining kind.
4121 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4122 return emitOpError("unsupported outerproduct type");
4123
4124 return success();
4125}
4126
4127// MaskableOpInterface methods.
4128
4129/// Returns the mask type expected by this operation. Mostly used for
4130/// verification purposes. It requires the operation to be vectorized.
4131Type OuterProductOp::getExpectedMaskType() {
4132 auto vecType = this->getResultVectorType();
4133 return VectorType::get(vecType.getShape(),
4134 IntegerType::get(vecType.getContext(), /*width=*/1),
4135 vecType.getScalableDims());
4136}
4137
4138//===----------------------------------------------------------------------===//
4139// ExtractStridedSliceOp
4140//===----------------------------------------------------------------------===//
4141
4142// Inference works as follows:
4143// 1. Add 'sizes' from prefix of dims in 'offsets'.
4144// 2. Add sizes from 'vectorType' for remaining dims.
4145// Scalable flags are inherited from 'vectorType'.
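// For example (illustrative): for vectorType = vector<4x8x16xf32> with
// offsets = [0, 2], sizes = [2, 4], strides = [1, 1], the inferred result
// type is vector<2x4x16xf32>.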
4146static Type inferStridedSliceOpResultType(VectorType vectorType,
4147 ArrayAttr offsets, ArrayAttr sizes,
4148 ArrayAttr strides) {
4149 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4150 SmallVector<int64_t, 4> shape;
4151 shape.reserve(vectorType.getRank());
4152 unsigned idx = 0;
4153 for (unsigned e = offsets.size(); idx < e; ++idx)
4154 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4155 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4156 shape.push_back(vectorType.getShape()[idx]);
4157
4158 return VectorType::get(shape, vectorType.getElementType(),
4159 vectorType.getScalableDims());
4160}
4161
4162void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4163 Value source, ArrayRef<int64_t> offsets,
4164 ArrayRef<int64_t> sizes,
4165 ArrayRef<int64_t> strides) {
4166 result.addOperands(source);
4167 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4168 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4169 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4170 result.addTypes(
4171 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4172 offsetsAttr, sizesAttr, stridesAttr));
4173 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4174 offsetsAttr);
4175 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4176 sizesAttr);
4177 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4178 stridesAttr);
4179}
4180
4181LogicalResult ExtractStridedSliceOp::verify() {
4182 auto type = getSourceVectorType();
4183 auto offsets = getOffsetsAttr();
4184 auto sizes = getSizesAttr();
4185 auto strides = getStridesAttr();
4186 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4187 return emitOpError(
4188 "expected offsets, sizes and strides attributes of same size");
4189
4190 auto shape = type.getShape();
4191 auto offName = getOffsetsAttrName();
4192 auto sizesName = getSizesAttrName();
4193 auto stridesName = getStridesAttrName();
4194 if (failed(
4195 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4196 failed(
4197 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4198 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4199 stridesName)) ||
4200 failed(
4201 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4202 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4203 /*halfOpen=*/false,
4204 /*min=*/1)) ||
4205 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4206 /*max=*/1, stridesName,
4207 /*halfOpen=*/false)) ||
4208 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4209 shape, offName, sizesName,
4210 /*halfOpen=*/false)))
4211 return failure();
4212
4213 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4214 offsets, sizes, strides);
4215 if (getResult().getType() != resultType)
4216 return emitOpError("expected result type to be ") << resultType;
4217
4218 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4219 if (type.getScalableDims()[idx]) {
4220 auto inputDim = type.getShape()[idx];
4221 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4222 if (inputDim != inputSize)
4223 return emitOpError("expected size at idx=")
4224 << idx
4225 << (" to match the corresponding base size from the input "
4226 "vector (")
4227 << inputSize << (" vs ") << inputDim << (")");
4228 }
4229 }
4230
4231 return success();
4232}
4233
4234// When the source of an ExtractStridedSliceOp comes from a chain of
4235// InsertStridedSliceOp ops, try to use the source of one of those inserts when
4236// we can detect that the extracted vector is a subset of an inserted vector.
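// For example (illustrative):
//   %0 = vector.insert_strided_slice %src, %dst
//          {offsets = [2, 0], strides = [1, 1]}
//          : vector<2x16xf32> into vector<8x16xf32>
//   %1 = vector.extract_strided_slice %0
//          {offsets = [3, 0], sizes = [1, 16], strides = [1, 1]}
//          : vector<8x16xf32> to vector<1x16xf32>
// The extracted chunk lies entirely within the inserted chunk, so %1 can
// instead extract from %src with offsets [1, 0].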
4237static LogicalResult
4238foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4239 // Helper to extract integer out of ArrayAttr.
4240 auto getElement = [](ArrayAttr array, int idx) {
4241 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4242 };
4243 ArrayAttr extractOffsets = op.getOffsets();
4244 ArrayAttr extractStrides = op.getStrides();
4245 ArrayAttr extractSizes = op.getSizes();
4246 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4247 while (insertOp) {
4248 if (op.getSourceVectorType().getRank() !=
4249 insertOp.getSourceVectorType().getRank())
4250 return failure();
4251 ArrayAttr insertOffsets = insertOp.getOffsets();
4252 ArrayAttr insertStrides = insertOp.getStrides();
4253 // If the rank of extract is greater than the rank of insert, we are likely
4254 // extracting a partial chunk of the vector inserted.
4255 if (extractOffsets.size() > insertOffsets.size())
4256 return failure();
4257 bool partialOverlap = false;
4258 bool disjoint = false;
4259 SmallVector<int64_t, 4> offsetDiffs;
4260 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4261 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4262 return failure();
4263 int64_t start = getElement(insertOffsets, dim);
4264 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4265 int64_t offset = getElement(extractOffsets, dim);
4266 int64_t size = getElement(extractSizes, dim);
4267 // Check if the start of the extract offset is in the interval inserted.
4268 if (start <= offset && offset < end) {
4269 // If the extract interval overlaps but is not fully included, we may
4270 // have a partial overlap that will prevent any folding.
4271 if (offset + size > end)
4272 partialOverlap = true;
4273 offsetDiffs.push_back(offset - start);
4274 continue;
4275 }
4276 disjoint = true;
4277 break;
4278 }
4279 // The extracted chunk is a subset of the inserted chunk.
4280 if (!disjoint && !partialOverlap) {
4281 op.setOperand(insertOp.getValueToStore());
4282 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4283 OpBuilder b(op.getContext());
4284 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4285 return success();
4286 }
4287 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4288 // in the insert chain.
4289 if (disjoint)
4290 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4291 else {
4292 // The extracted vector partially overlaps the inserted vector; we cannot
4293 // fold.
4294 return failure();
4295 }
4296 }
4297 return failure();
4298}
4299
4300// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4301static OpFoldResult
4302foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
4303 Attribute foldInput) {
4304
4305 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4306 if (!dense)
4307 return {};
4308
4309 // TODO: Handle non-unit strides when they become available.
4310 if (op.hasNonUnitStrides())
4311 return {};
4312
4313 VectorType sourceVecTy = op.getSourceVectorType();
4314 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4315 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4316
4317 VectorType sliceVecTy = op.getType();
4318 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4319 int64_t rank = sliceVecTy.getRank();
4320
4321 // Expand offsets and sizes to match the vector rank.
4322 SmallVector<int64_t, 4> offsets(rank, 0);
4323 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4324
4325 SmallVector<int64_t, 4> sizes(sourceShape);
4326 copy(getI64SubArray(op.getSizes()), sizes.begin());
4327
4328 // Calculate the slice elements by enumerating all slice positions and
4329 // linearizing them. The enumeration order is lexicographic which yields a
4330 // sequence of monotonically increasing linearized position indices.
4331 const auto denseValuesBegin = dense.value_begin<Attribute>();
4332 SmallVector<Attribute> sliceValues;
4333 sliceValues.reserve(sliceVecTy.getNumElements());
4334 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4335 do {
4336 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4337 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4338 "Invalid index");
4339 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4340 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4341
4342 assert(static_cast<int64_t>(sliceValues.size()) ==
4343 sliceVecTy.getNumElements() &&
4344 "Invalid number of slice elements");
4345 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4346}
4347
4348OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4349 if (getSourceVectorType() == getResult().getType())
4350 return getSource();
4351 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4352 return getResult();
4353
4354 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4355 if (auto splat =
4356 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4357 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4358
4359 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4360 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4361}
4362
4363void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4364 populateFromInt64AttrArray(getOffsets(), results);
4365}
4366
4367namespace {
4368
4369// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
4370//
4371// Example:
4372//
4373// %0 = vector.extract_strided_slice %arg0
4374// {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
4375// : vector<4x8x16xf32> to vector<3x4x16xf32>
4376// %1 = vector.extract_strided_slice %0
4377// {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
4378// : vector<3x4x16xf32> to vector<2x2x16xf32>
4379//
4380// to
4381//
4382// %1 = vector.extract_strided_slice %arg0
4383// {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
4384// : vector<4x8x16xf32> to vector<2x2x16xf32>
4385class StridedSliceFolder final
4386 : public OpRewritePattern<ExtractStridedSliceOp> {
4387public:
4388 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4389
4390 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4391 PatternRewriter &rewriter) const override {
4392 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4393 if (!firstOp)
4394 return failure();
4395
4396 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4397 return failure();
4398
4399 SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
4400 SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
4401 SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
4402 SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());
4403
4404 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4405 SmallVector<int64_t> combinedOffsets(newRank, 0);
4406 SmallVector<int64_t> combinedSizes(newRank);
4407 ArrayRef<int64_t> firstSourceShape =
4408 firstOp.getSourceVectorType().getShape();
4409 for (unsigned i = 0; i < newRank; ++i) {
4410 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4411 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4412 combinedOffsets[i] = off1 + off2;
4413
4414 if (i < secondSizes.size()) {
4415 combinedSizes[i] = secondSizes[i];
4416 } else if (i < firstSizes.size()) {
4417 combinedSizes[i] = firstSizes[i];
4418 } else {
4419 combinedSizes[i] = firstSourceShape[i];
4420 }
4421 }
4422
4423 SmallVector<int64_t> combinedStrides(newRank, 1);
4424 rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
4425 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4426 combinedStrides);
4427 return success();
4428 }
4429};
4430
4431// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4432// CreateMaskOp.
4433//
4434// Example:
4435//
4436// %mask = vector.create_mask %ub : vector<16xi1>
4437// %slice = vector.extract_strided_slice [%offset] [8] [1]
4438//
4439// to
4440//
4441// %new_ub = arith.subi %ub, %offset
4442// %mask = vector.create_mask %new_ub : vector<8xi1>
4443class StridedSliceCreateMaskFolder final
4444 : public OpRewritePattern<ExtractStridedSliceOp> {
4445 using Base::Base;
4446
4447public:
4448 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4449 PatternRewriter &rewriter) const override {
4450 Location loc = extractStridedSliceOp.getLoc();
4451 // Return if 'extractStridedSliceOp' operand is not defined by a
4452 // CreateMaskOp.
4453 auto createMaskOp =
4454 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4455 if (!createMaskOp)
4456 return failure();
4457 // Return if 'extractStridedSliceOp' has non-unit strides.
4458 if (extractStridedSliceOp.hasNonUnitStrides())
4459 return failure();
4460 // Gather constant mask dimension sizes.
4461 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4462 // Gather strided slice offsets and sizes.
4463 SmallVector<int64_t> sliceOffsets;
4464 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4465 sliceOffsets);
4466 SmallVector<int64_t> sliceSizes;
4467 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4468
4469 // Compute slice of vector mask region.
4470 SmallVector<Value> sliceMaskDimSizes;
4471 sliceMaskDimSizes.reserve(maskDimSizes.size());
4472 // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4473 // only iterate on the leading dim sizes. The tail accounts for the
4474 // remaining dim sizes.
4475 for (auto [maskDimSize, sliceOffset, sliceSize] :
4476 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4477 // No need to clamp on min/max values, because create_mask has clamping
4478 // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4479 // greater than the vector dim size.
4480 IntegerAttr offsetAttr =
4481 rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4482 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4483 Value sliceMaskDimSize =
4484 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4485 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4486 }
4487 // Add unchanged dimensions.
4488 llvm::append_range(
4489 sliceMaskDimSizes,
4490 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4491 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4492 // region.
4493 rewriter.replaceOpWithNewOp<CreateMaskOp>(
4494 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4495 sliceMaskDimSizes);
4496 return success();
4497 }
4498};
4499
4500// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4501// ConstantMaskOp.
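//
// Example (illustrative only; shapes chosen for exposition):
//
//   %mask = vector.constant_mask [10] : vector<16xi1>
//   %slice = vector.extract_strided_slice %mask
//       {offsets = [4], sizes = [8], strides = [1]}
//       : vector<16xi1> to vector<8xi1>
//
// becomes
//
//   %slice = vector.constant_mask [6] : vector<8xi1>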
4502class StridedSliceConstantMaskFolder final
4503 : public OpRewritePattern<ExtractStridedSliceOp> {
4504public:
4505 using Base::Base;
4506
4507 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4508 PatternRewriter &rewriter) const override {
4509 // Return if 'extractStridedSliceOp' operand is not defined by a
4510 // ConstantMaskOp.
4511 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4512 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4513 if (!constantMaskOp)
4514 return failure();
4515 // Return if 'extractStridedSliceOp' has non-unit strides.
4516 if (extractStridedSliceOp.hasNonUnitStrides())
4517 return failure();
4518 // Gather constant mask dimension sizes.
4519 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4520 // Gather strided slice offsets and sizes.
4521 SmallVector<int64_t> sliceOffsets;
4522 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4523 sliceOffsets);
4524 SmallVector<int64_t> sliceSizes;
4525 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4526
4527 // Compute slice of vector mask region.
4528 SmallVector<int64_t> sliceMaskDimSizes;
4529 sliceMaskDimSizes.reserve(maskDimSizes.size());
4530 for (auto [maskDimSize, sliceOffset, sliceSize] :
4531 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4532 int64_t sliceMaskDimSize = std::max(
4533 static_cast<int64_t>(0),
4534 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4535 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4536 }
4537 // Add unchanged dimensions.
4538 llvm::append_range(
4539 sliceMaskDimSizes,
4540 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4541 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4542 // region is a conjunction of mask dim intervals).
4543 if (llvm::is_contained(sliceMaskDimSizes, 0))
4544 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4545
4546 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4547 // region.
4548 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4549 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4550 sliceMaskDimSizes);
4551 return success();
4552 }
4553};
4554
4555// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4556// BroadcastOp(ExtractStridedSliceOp).
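//
// Example (illustrative only; shapes chosen for exposition):
//
//   %b = vector.broadcast %src : vector<1x8xf32> to vector<4x8xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x4xf32>
//
// becomes
//
//   %0 = vector.extract_strided_slice %src
//       {offsets = [0, 2], sizes = [1, 4], strides = [1, 1]}
//       : vector<1x8xf32> to vector<1x4xf32>
//   %s = vector.broadcast %0 : vector<1x4xf32> to vector<2x4xf32>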
4557class StridedSliceBroadcast final
4558 : public OpRewritePattern<ExtractStridedSliceOp> {
4559public:
4560 using Base::Base;
4561
4562 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4563 PatternRewriter &rewriter) const override {
4564 auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4565 if (!broadcast)
4566 return failure();
4567 auto srcVecType =
4568 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4569 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4570 auto dstVecType = llvm::cast<VectorType>(op.getType());
4571 unsigned dstRank = dstVecType.getRank();
4572 unsigned rankDiff = dstRank - srcRank;
4573 // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4574 // (n -> m with n > m). If they are originally both broadcasted *and*
4575 // sliced, this can be simplified to just broadcasting.
4576 bool needsSlice = false;
4577 for (unsigned i = 0; i < srcRank; i++) {
4578 if (srcVecType.getDimSize(i) != 1 &&
4579 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4580 needsSlice = true;
4581 break;
4582 }
4583 }
4584 Value source = broadcast.getSource();
4585 if (needsSlice) {
4586 SmallVector<int64_t> offsets =
4587 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4588 SmallVector<int64_t> sizes =
4589 getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4590 for (unsigned i = 0; i < srcRank; i++) {
4591 if (srcVecType.getDimSize(i) == 1) {
4592 // In case this dimension was broadcasted *and* sliced, the offset
4593 // and size need to be updated now that there is no broadcast before
4594 // the slice.
4595 offsets[i] = 0;
4596 sizes[i] = 1;
4597 }
4598 }
4599 source = ExtractStridedSliceOp::create(
4600 rewriter, op->getLoc(), source, offsets, sizes,
4601 getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4602 }
4603 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4604 return success();
4605 }
4606};
4607
4608/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
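///
/// Example (illustrative only):
///
///   %splat = vector.broadcast %s : f32 to vector<8xf32>
///   %slice = vector.extract_strided_slice %splat
///       {offsets = [2], sizes = [4], strides = [1]}
///       : vector<8xf32> to vector<4xf32>
///
/// becomes
///
///   %slice = vector.broadcast %s : f32 to vector<4xf32>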
4609class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4610public:
4611 using Base::Base;
4612
4613 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4614 PatternRewriter &rewriter) const override {
4615
4616 Value splat = getScalarSplatSource(op.getSource());
4617 if (!splat)
4618 return failure();
4619 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4620 return success();
4621 }
4622};
4623
4624/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4625/// slice is contiguous, into extract and shape_cast.
4626///
4627/// Example:
4628/// Before:
4629/// %1 = vector.extract_strided_slice %arg0 {
4630/// offsets = [0, 0, 0, 0, 0],
4631/// sizes = [1, 1, 1, 1, 8],
4632/// strides = [1, 1, 1, 1, 1]
4633/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4634/// After:
4635/// %0 = vector.extract %arg0[0, 0, 0, 0]
4636/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4637/// %1 = vector.shape_cast %0
4638/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4639///
4640class ContiguousExtractStridedSliceToExtract final
4641 : public OpRewritePattern<ExtractStridedSliceOp> {
4642public:
4643 using Base::Base;
4644
4645 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4646 PatternRewriter &rewriter) const override {
4647 if (op.hasNonUnitStrides())
4648 return failure();
4649 Value source = op.getOperand();
4650 auto sourceType = cast<VectorType>(source.getType());
4651 if (sourceType.isScalable() || sourceType.getRank() == 0)
4652 return failure();
4653
4654 // Compute the number of offsets to pass to ExtractOp::build. That is the
4655 // difference between the source rank and the desired slice rank. We walk
4656 // the dimensions from innermost out, and stop when the next slice dimension
4657 // is not full-size.
4658 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4659 int numOffsets;
4660 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4661 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4662 break;
4663 }
4664
4665 // If the created extract op would have no offsets, then this whole
4666 // extract_strided_slice is the identity and should have been handled by
4667 // other canonicalizations.
4668 if (numOffsets == 0)
4669 return failure();
4670
4671 // If not even the inner-most dimension is full-size, this op can't be
4672 // rewritten as an ExtractOp.
4673 if (numOffsets == sourceType.getRank() &&
4674 static_cast<int>(sizes.size()) == sourceType.getRank())
4675 return failure();
4676
4677 // The outer dimensions must have unit size.
4678 for (int i = 0; i < numOffsets; ++i) {
4679 if (sizes[i] != 1)
4680 return failure();
4681 }
4682
4683 // Avoid generating slices that have leading unit dimensions. The shape_cast
4684 // op that we create below would otherwise be lowered through the slow
4685 // generic fallback patterns (ShapeCastOpRewritePattern).
4686 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4687 sizes[numOffsets] == 1) {
4688 ++numOffsets;
4689 }
4690
4691 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4692 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4693 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4694 extractOffsets);
4695 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4696 return success();
4697 }
4698};
4699
4700} // namespace
4701
4702void ExtractStridedSliceOp::getCanonicalizationPatterns(
4703 RewritePatternSet &results, MLIRContext *context) {
4704 // Patterns that simplify ExtractStridedSliceOp over nested slices, masks
4705 // (create_mask/constant_mask), broadcasts, splats, and contiguous slices.
4706 results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
4707 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
4708 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4709 context);
4710}
4711
4712//===----------------------------------------------------------------------===//
4713// TransferReadOp
4714//===----------------------------------------------------------------------===//
4715
4716/// 1. Builder that sets padding to a poison value (if absent) and an empty mask (variant with attrs).
4717void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4718 VectorType vectorType, Value source,
4719 ValueRange indices, std::optional<Value> padding,
4720 AffineMapAttr permutationMapAttr,
4721 /*optional*/ ArrayAttr inBoundsAttr) {
4722
4723 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4724 if (!padding)
4725 padding = ub::PoisonOp::create(builder, result.location, elemType);
4726 build(builder, result, vectorType, source, indices, permutationMapAttr,
4727 *padding, /*mask=*/Value(), inBoundsAttr);
4728}
4729
4730/// 2. Builder that sets padding to a poison value (if absent) and an empty mask (variant without attrs).
4731void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4732 VectorType vectorType, Value source,
4733 ValueRange indices, std::optional<Value> padding,
4734 AffineMap permutationMap,
4735 std::optional<ArrayRef<bool>> inBounds) {
4736 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4737 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4738 ? builder.getBoolArrayAttr(inBounds.value())
4739 : builder.getBoolArrayAttr(
4740 SmallVector<bool>(vectorType.getRank(), false));
4741 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4742 if (!padding)
4743 padding = ub::PoisonOp::create(builder, result.location, elemType);
4744 build(builder, result, vectorType, source, indices, *padding,
4745 permutationMapAttr, inBoundsAttr);
4746}
4747
4748/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4749void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4750 VectorType vectorType, Value source,
4751 ValueRange indices, std::optional<Value> padding,
4752 std::optional<ArrayRef<bool>> inBounds) {
4753 AffineMap permutationMap = getTransferMinorIdentityMap(
4754 llvm::cast<ShapedType>(source.getType()), vectorType);
4755 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4756 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4757 ? builder.getBoolArrayAttr(inBounds.value())
4758 : builder.getBoolArrayAttr(
4759 SmallVector<bool>(vectorType.getRank(), false));
4760 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4761 if (!padding)
4762 padding = ub::PoisonOp::create(builder, result.location, elemType);
4763 build(builder, result, vectorType, source, indices, permutationMapAttr,
4764 *padding,
4765 /*mask=*/Value(), inBoundsAttr);
4766}
4767
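/// Verifies that `permutationMap` is a projected permutation: every result is
/// either a unique dim expression or the constant 0. Illustrative examples:
///
///   (d0, d1, d2) -> (d2, 0, d0)    // accepted
///   (d0, d1, d2) -> (d0 + 1, d1)   // rejected: result is not a dim or 0
///   (d0, d1, d2) -> (d1, d1)       // rejected: d1 appears twice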
4768template <typename EmitFun>
4769static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4770 EmitFun emitOpError) {
4771 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4772 for (auto expr : permutationMap.getResults()) {
4773 auto dim = dyn_cast<AffineDimExpr>(expr);
4774 auto zero = dyn_cast<AffineConstantExpr>(expr);
4775 if (zero) {
4776 if (zero.getValue() != 0) {
4777 return emitOpError(
4778 "requires a projected permutation_map (at most one dim or the zero "
4779 "constant can appear in each result)");
4780 }
4781 continue;
4782 }
4783 if (!dim) {
4784 return emitOpError("requires a projected permutation_map (at most one "
4785 "dim or the zero constant can appear in each result)");
4786 }
4787 if (seen[dim.getPosition()]) {
4788 return emitOpError(
4789 "requires a permutation_map that is a permutation (found one dim "
4790 "used more than once)");
4791 }
4792 seen[dim.getPosition()] = true;
4793 }
4794 return success();
4795}
4796
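/// Common verifier shared by transfer_read and transfer_write: checks that the
/// source is a memref or ranked tensor, that the element bitwidths are
/// compatible, and that the permutation map, mask type and in_bounds attribute
/// are consistent with the source and vector types.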
4797static LogicalResult
4798verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4799 VectorType vectorType, VectorType maskType,
4800 VectorType inferredMaskType, AffineMap permutationMap,
4801 ArrayAttr inBounds) {
4802 if (op->hasAttr("masked")) {
4803 return op->emitOpError("masked attribute has been removed. "
4804 "Use in_bounds instead.");
4805 }
4806
4807 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4808 return op->emitOpError(
4809 "requires source to be a memref or ranked tensor type");
4810
4811 auto elementType = shapedType.getElementType();
4812 DataLayout dataLayout = DataLayout::closest(op);
4813 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4814 // Memref or tensor has vector element type.
4815 unsigned sourceVecSize =
4816 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4817 vectorElementType.getShape().back();
4818 unsigned resultVecSize =
4819 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4820 vectorType.getShape().back();
4821 if (resultVecSize % sourceVecSize != 0)
4822 return op->emitOpError(
4823 "requires the bitwidth of the minor 1-D vector to be an integral "
4824 "multiple of the bitwidth of the minor 1-D vector of the source");
4825
4826 unsigned sourceVecEltRank = vectorElementType.getRank();
4827 unsigned resultVecRank = vectorType.getRank();
4828 if (sourceVecEltRank > resultVecRank)
4829 return op->emitOpError(
4830 "requires source vector element and vector result ranks to match.");
4831 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4832 // Check that permutation map results match 'rankOffset' of vector type.
4833 if (permutationMap.getNumResults() != rankOffset)
4834 return op->emitOpError("requires a permutation_map with result dims of "
4835 "the same rank as the vector type");
4836
4837 if (maskType)
4838 return op->emitOpError("does not support masks with vector element type");
4839 } else {
4840 // Memref or tensor has scalar element type.
4841 unsigned minorSize =
4842 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4843 unsigned resultVecSize =
4844 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4845 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4846 return op->emitOpError(
4847 "requires the bitwidth of the minor 1-D vector to be an integral "
4848 "multiple of the bitwidth of the source element type");
4849
4850 // Check that permutation map results match rank of vector type.
4851 if (permutationMap.getNumResults() != vectorType.getRank())
4852 return op->emitOpError("requires a permutation_map with result dims of "
4853 "the same rank as the vector type");
4854 }
4855
4856 if (permutationMap.getNumSymbols() != 0)
4857 return op->emitOpError("requires permutation_map without symbols");
4858
4859 if (permutationMap.getNumInputs() != shapedType.getRank())
4860 return op->emitOpError("requires a permutation_map with input dims of the "
4861 "same rank as the source type");
4862
4863 if (maskType && maskType != inferredMaskType)
4864 return op->emitOpError("inferred mask type (")
4865 << inferredMaskType << ") and mask operand type (" << maskType
4866 << ") don't match";
4867
4868 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4869 return op->emitOpError("expects the in_bounds attr of same rank "
4870 "as permutation_map results: ")
4871 << AffineMapAttr::get(permutationMap)
4872 << " vs inBounds of size: " << inBounds.size();
4873
4874 return success();
4875}
4876
4877static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4878 SmallVector<StringRef, 3> elidedAttrs;
4879 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4880 if (op.getPermutationMap().isMinorIdentity())
4881 elidedAttrs.push_back(op.getPermutationMapAttrName());
4882 // Elide in_bounds attribute if all dims are out-of-bounds.
4883 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4884 elidedAttrs.push_back(op.getInBoundsAttrName());
4885 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4886}
4887
4888void TransferReadOp::print(OpAsmPrinter &p) {
4889 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4890 if (getMask())
4891 p << ", " << getMask();
4892 printTransferAttrs(p, *this);
4893 p << " : " << getShapedType() << ", " << getVectorType();
4894}
4895
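// For example (illustrative): a transfer with vector type vector<3x4xf32> and
// permutation_map (d0, d1, d2) -> (d2, d1) expects a mask of type
// vector<4x3xi1>, i.e. the vector shape permuted back into the (compressed)
// order of the accessed source dims.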
4896VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4897 AffineMap permMap) {
4898 auto i1Type = IntegerType::get(permMap.getContext(), 1);
4899 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4900 assert(invPermMap && "Inverse permutation map couldn't be computed");
4901 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4902
4903 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4904 // 0-D mask into a single-element 1-D mask.
4905 if (maskShape.empty())
4906 maskShape.push_back(1);
4907
4908 SmallVector<bool> scalableDims =
4909 applyPermutationMap(invPermMap, vecType.getScalableDims());
4910
4911 return VectorType::get(maskShape, i1Type, scalableDims);
4912}
4913
4914ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4915 auto &builder = parser.getBuilder();
4916 SMLoc typesLoc;
4917 OpAsmParser::UnresolvedOperand sourceInfo;
4918 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4919 OpAsmParser::UnresolvedOperand paddingInfo;
4920 SmallVector<Type, 2> types;
4921 OpAsmParser::UnresolvedOperand maskInfo;
4922 // Parsing with support for paddingValue.
4923 if (parser.parseOperand(sourceInfo) ||
4924 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4925 parser.parseComma() || parser.parseOperand(paddingInfo))
4926 return failure();
4927 ParseResult hasMask = parser.parseOptionalComma();
4928 if (hasMask.succeeded()) {
4929 if (parser.parseOperand(maskInfo))
4930 return failure();
4931 }
4932 if (parser.parseOptionalAttrDict(result.attributes) ||
4933 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4934 return failure();
4935 if (types.size() != 2)
4936 return parser.emitError(typesLoc, "requires two types");
4937 auto indexType = builder.getIndexType();
4938 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4939 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4940 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4941 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4942 if (!vectorType)
4943 return parser.emitError(typesLoc, "requires vector type");
4944 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4945 Attribute permMapAttr = result.attributes.get(permMapAttrName);
4946 AffineMap permMap;
4947 if (!permMapAttr) {
4948 if (shapedType.getRank() <
4949 getEffectiveVectorRankForXferOp(shapedType, vectorType))
4950 return parser.emitError(typesLoc,
4951 "expected a custom permutation_map when "
4952 "rank(source) != rank(destination)");
4953 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4954 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4955 } else {
4956 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4957 }
4958 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4959 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4960 if (!inBoundsAttr) {
4961 result.addAttribute(inBoundsAttrName,
4962 builder.getBoolArrayAttr(
4963 SmallVector<bool>(permMap.getNumResults(), false)));
4964 }
4965 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4966 parser.resolveOperands(indexInfo, indexType, result.operands) ||
4967 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4968 result.operands))
4969 return failure();
4970 if (hasMask.succeeded()) {
4971 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4972 return parser.emitError(
4973 maskInfo.location, "does not support masks with vector element type");
4974 if (vectorType.getRank() != permMap.getNumResults()) {
4975 return parser.emitError(typesLoc,
4976 "expected the same rank for the vector and the "
4977 "results of the permutation map");
4978 }
4979 // Instead of adding the mask type as an op type, compute it based on the
4980 // vector type and the permutation map (to keep the type signature small).
4981 auto maskType = inferTransferOpMaskType(vectorType, permMap);
4982 if (parser.resolveOperand(maskInfo, maskType, result.operands))
4983 return failure();
4984 }
4985 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4986 builder.getDenseI32ArrayAttr(
4987 {1, static_cast<int32_t>(indexInfo.size()), 1,
4988 static_cast<int32_t>(hasMask.succeeded())}));
4989 return parser.addTypeToList(vectorType, result.types);
4990}
4991
4992LogicalResult TransferReadOp::verify() {
4993 // Consistency of elemental types in source and vector.
4994 ShapedType shapedType = getShapedType();
4995 VectorType vectorType = getVectorType();
4996 VectorType maskType = getMaskType();
4997 auto paddingType = getPadding().getType();
4998 auto permutationMap = getPermutationMap();
4999 VectorType inferredMaskType =
5000 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5001 : VectorType();
5002 auto sourceElementType = shapedType.getElementType();
5003
5004 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
5005 return emitOpError("requires ") << shapedType.getRank() << " indices";
5006
5007 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5008 shapedType, vectorType, maskType,
5009 inferredMaskType, permutationMap, getInBounds())))
5010 return failure();
5011
5012 if (auto sourceVectorElementType =
5013 llvm::dyn_cast<VectorType>(sourceElementType)) {
5014 // Source has vector element type.
5015 // Check that 'sourceVectorElementType' and 'paddingType' types match.
5016 if (sourceVectorElementType != paddingType)
5017 return emitOpError(
5018 "requires source element type and padding type to match.");
5019
5020 } else {
5021 // Check that 'paddingType' is valid to store in a vector type.
5022 if (!VectorType::isValidElementType(paddingType))
5023 return emitOpError("requires valid padding vector elemental type");
5024
5025 // Check that padding type and vector element types match.
5026 if (paddingType != sourceElementType)
5027 return emitOpError(
5028 "requires formal padding and source of the same elemental type");
5029 }
5030
5031 return verifyPermutationMap(permutationMap,
5032 [&](Twine t) { return emitOpError(t); });
5033}
5034
5035// MaskableOpInterface methods.
5036
5037/// Returns the mask type expected by this operation. Mostly used for
5038/// verification purposes. It requires the operation to be vectorized.
5039Type TransferReadOp::getExpectedMaskType() {
5040 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5041}
5042
5043//===----------------------------------------------------------------------===//
5044// TransferReadOp: VectorTransferOpInterface methods.
5045//===----------------------------------------------------------------------===//
5046VectorType TransferReadOp::getVectorType() {
5047 return cast<VectorType>(getVector().getType());
5048}
5049
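/// Returns true when the access along `resultIdx`/`indicesIdx` is statically
/// known to stay within the source bounds. A sketch of the check (constant
/// index, static source dim):
///
///   constantIndex + vectorDimSize <= sourceDimSize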
5050template <typename TransferOp>
5051static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
5052 // TODO: support more aggressive createOrFold on:
5053 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
5054 if (op.getShapedType().isDynamicDim(indicesIdx))
5055 return false;
5056 Value index = op.getIndices()[indicesIdx];
5057 std::optional<int64_t> cstOp = getConstantIntValue(index);
5058 if (!cstOp.has_value())
5059 return false;
5060
5061 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5062 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5063
5064 return cstOp.value() + vectorSize <= sourceSize;
5065}
5066
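/// Tightens the in_bounds attribute when constant indices together with static
/// source and vector shapes prove a dimension in-bounds. For example
/// (illustrative shapes):
///
///   %r = vector.transfer_read %src[%c2], %pad {in_bounds = [false]}
///       : memref<16xf32>, vector<4xf32>
///
/// can be rewritten with in_bounds = [true], since 2 + 4 <= 16.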
5067template <typename TransferOp>
5068static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
5069 // TODO: support 0-d corner case.
5070 // TODO: Be less conservative.
5071 if (op.getTransferRank() == 0)
5072 return failure();
5073 AffineMap permutationMap = op.getPermutationMap();
5074 bool changed = false;
5075 SmallVector<bool, 4> newInBounds;
5076 newInBounds.reserve(op.getTransferRank());
5077 // Idxs of non-bcast dims - used when analysing bcast dims.
5078 SmallVector<unsigned> nonBcastDims;
5079
5080 // 1. Process non-broadcast dims
5081 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5082 // 1.1. Already marked as in-bounds, nothing to see here.
5083 if (op.isDimInBounds(i)) {
5084 newInBounds.push_back(true);
5085 continue;
5086 }
5087 // 1.2. Currently out-of-bounds, check whether we can statically determine
5088 // it is inBounds.
5089 bool inBounds = false;
5090 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5091 if (dimExpr) {
5092 inBounds = isInBounds(op, /*resultIdx=*/i,
5093 /*indicesIdx=*/dimExpr.getPosition());
5094 nonBcastDims.push_back(i);
5095 }
5096
5097 newInBounds.push_back(inBounds);
5098 // We commit the pattern if it is "more inbounds".
5099 changed |= inBounds;
5100 }
5101
5102 // 2. Handle broadcast dims
5103 // If all non-broadcast dims are "in bounds", then all bcast dims should be
5104 // "in bounds" as well.
5105 bool allNonBcastDimsInBounds = llvm::all_of(
5106 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5107 if (allNonBcastDimsInBounds) {
5108 for (size_t idx : permutationMap.getBroadcastDims()) {
5109 changed |= !newInBounds[idx];
5110 newInBounds[idx] = true;
5111 }
5112 }
5113
5114 if (!changed)
5115 return failure();
5116 // OpBuilder is only used as a helper to build an I64ArrayAttr.
5117 OpBuilder b(op.getContext());
5118 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5119 return success();
5120}
5121
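/// Drops the mask operand when it is statically known to be all-true, e.g.
/// (illustrative):
///
///   %m = vector.constant_mask [4] : vector<4xi1>
///   %r = vector.transfer_read %src[%c0], %pad, %m
///       : memref<?xf32>, vector<4xf32>
///
/// folds to the unmasked form:
///
///   %r = vector.transfer_read %src[%c0], %pad : memref<?xf32>, vector<4xf32>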
5122template <typename TransferOp>
5123static LogicalResult foldTransferFullMask(TransferOp op) {
5124 auto mask = op.getMask();
5125 if (!mask)
5126 return failure();
5127
5128 if (getMaskFormat(mask) != MaskFormat::AllTrue)
5129 return failure();
5130
5131 op.getMaskMutable().clear();
5132 return success();
5133}
5134
5135/// ```
5136/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5137/// : vector<1x4xf32>, tensor<4x4xf32>
5138/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5139/// : tensor<4x4xf32>, vector<1x4xf32>
5140/// ```
5141/// -> Folds into
5142/// ```
5143/// %v0
5144/// ```
5145static Value foldRAW(TransferReadOp readOp) {
5146 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5147 return {};
5148 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5149 while (defWrite) {
5150 if (checkSameValueRAW(defWrite, readOp))
5151 return defWrite.getVector();
5152 if (!isDisjointTransferIndices(
5153 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5154 cast<VectorTransferOpInterface>(readOp.getOperation())))
5155 break;
5156 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5157 }
5158 return {};
5159}
5160
5161OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5162 if (Value vec = foldRAW(*this))
5163 return vec;
5164 /// transfer_read(memrefcast) -> transfer_read
5165 if (succeeded(foldTransferInBoundsAttribute(*this)))
5166 return getResult();
5167 if (succeeded(foldTransferFullMask(*this)))
5168 return getResult();
5169 if (succeeded(memref::foldMemRefCast(*this)))
5170 return getResult();
5171 if (succeeded(tensor::foldTensorCast(*this)))
5172 return getResult();
5173 return OpFoldResult();
5174}
5175
5176std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5177 return llvm::to_vector<4>(getVectorType().getShape());
5178}
5179
5180void TransferReadOp::getEffects(
5181 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5182 &effects) {
5183 if (llvm::isa<MemRefType>(getShapedType()))
5184 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5185 SideEffects::DefaultResource::get());
5186}
5187
5188Speculation::Speculatability TransferReadOp::getSpeculatability() {
5189 if (hasPureTensorSemantics())
5190 return Speculation::Speculatable;
5191 return Speculation::NotSpeculatable;
5192}
5193
5194/// Given a projected permutation, inverts the affine map, setting the unused
5195/// dims to 0 in the result.
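/// For example (derived from the logic below):
///
///   (d0, d1, d2) -> (d2, d0)
///
/// is inverted to
///
///   (d0, d1) -> (d1, 0, d0)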
5196static AffineMap inverseWithUnusedDims(AffineMap map) {
5197 assert(map.isProjectedPermutation() &&
5198 "expected a projected permutation map");
5199 SmallVector<AffineExpr> results(map.getNumInputs(),
5200 getAffineConstantExpr(0, map.getContext()));
5201 for (auto [idx, result] : llvm::enumerate(map.getResults())) {
5202 // We should only have dim exprs because this is a projected permutation.
5203 int64_t pos = cast<AffineDimExpr>(result).getPosition();
5204 results[pos] = getAffineDimExpr(idx, map.getContext());
5205 }
5206 return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
5207 results, map.getContext());
5208}
5209
5210namespace {
5211/// Store to load forwarding for transfer operations with permutation maps.
5212/// Even if the permutation maps are different we can still propagate the store
5213/// into the load if the size of the dimensions read and written match. Then we
5214/// can replace the transfer_read + transfer_write by vector.broadcast and
5215/// vector.transpose.
5216/// Example:
5217/// ```
5218/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5219/// {in_bounds = [true, true],
5220/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5221/// vector<4x1xf32>, tensor<4x4x4xf32>
5222/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5223/// {in_bounds = [true, true, true, true],
5224/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5225/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5226/// ```
5227/// To:
5228/// ```
5229/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
5230/// %r = vector.transpose %0, [3, 0, 2, 1] :
5231/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5232/// ```
5233struct TransferReadAfterWriteToBroadcast
5234 : public OpRewritePattern<TransferReadOp> {
5235 using Base::Base;
5236
5237 LogicalResult matchAndRewrite(TransferReadOp readOp,
5238 PatternRewriter &rewriter) const override {
5239 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5240 if (!defWrite)
5241 return failure();
5242 // Bail if we need an alias analysis.
5243 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5244 return failure();
5245 // Bail in the masked case (too complex atm and needed to properly account
5246 // for padding).
5247 if (readOp.getMask() || defWrite.getMask())
5248 return failure();
5249 // If indices are not the same a shift may be required, bail.
5250 if (readOp.getIndices() != defWrite.getIndices())
5251 return failure();
5252 // Bail if we need a bounds analysis.
5253 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5254 return failure();
5255 // TODO: If the written transfer chunk is a superset of the read transfer
5256 // chunk we could do an extract_strided_slice.
5257 if (readOp.getTransferChunkAccessed() !=
5258 defWrite.getTransferChunkAccessed())
5259 return failure();
5260 // WriteMap: tensor -> w_vec
5261 // ReadMap: tensor -> r_vec
5262 //
5263 // inv(WriteMap): w_vec -> tensor
5264 // inv(WriteMap) o ReadMap: w_vec -> r_vec
5265 AffineMap readMap = readOp.getPermutationMap();
5266 AffineMap writeMap = defWrite.getPermutationMap();
5267 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5268 AffineMap composedMap = readMap.compose(invWriteMap);
5269 // If there are any unused dims in the composedMap, we have to drop some
5270 // unit dims from the written vector before we can do transpose(broadcast).
5271 // TODO: Support this case.
5272 if (getUnusedDimsBitVector(composedMap).any())
5273 return failure();
5274 // readVec = transpose(broadcast(writeVec))
5275 //
5276 // Build a transpose permutation for the above transpose operation.
5277 //
5278 // Treat the composed map as having extra leading dimensions which are
5279 // the broadcasted dimensions, and treat the zeros as these new broadcasted
5280 // dimensions.
5281 SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
5282 int64_t numBroadcastedDims = broadcastedDims.size();
5283 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5284 invPerm.resize(composedMap.getNumResults());
5285 for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
5286 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
5287 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5288 invPerm[effectiveDim] = idx;
5289 }
5290 }
5291 // Applying the inverse permutation on the readVecTy will give us the
5292 // broadcast result type.
5293 VectorType readVecTy = readOp.getVectorType();
5294 SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
5295 auto broadcastedVecTy =
5296 VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
5297 readVecTy.getElementType(),
5298 applyPermutation(readVecTy.getScalableDims(), invPerm));
5299 // Build the transpose(broadcast) transformation.
5300 Value vec = defWrite.getVector();
5301 Location loc = readOp.getLoc();
5302 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5303 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
5304 return success();
5305 }
5306};
5307} // namespace
5308
5309void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5310 MLIRContext *context) {
5311 results.add<TransferReadAfterWriteToBroadcast>(context);
5312}
5313
5314FailureOr<std::optional<SmallVector<Value>>>
5315TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5316 if (!hasPureBufferSemantics())
5317 return failure();
5319 getResult());
5320}
5321
5322//===----------------------------------------------------------------------===//
5323// TransferWriteOp
5324//===----------------------------------------------------------------------===//
5325
5326/// 1. Builder with type inference.
5327void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5328 Value vector, Value dest, ValueRange indices,
5329 AffineMapAttr permutationMapAttr,
5330 /*optional*/ Value mask,
5331 /*optional*/ ArrayAttr inBoundsAttr) {
5332 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5333 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5334 mask, inBoundsAttr);
5335}
5336
5337/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5338void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5339 Value vector, Value dest, ValueRange indices,
5340 AffineMapAttr permutationMapAttr,
5341 /*optional*/ ArrayAttr inBoundsAttr) {
5342 build(builder, result, vector, dest, indices, permutationMapAttr,
5343 /*mask=*/Value(), inBoundsAttr);
5344}
5345
5346/// 3. Builder with type inference that sets an empty mask (variant without
5347/// attrs)
5348void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5349 Value vector, Value dest, ValueRange indices,
5350 AffineMap permutationMap,
5351 std::optional<ArrayRef<bool>> inBounds) {
5352 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5353 auto inBoundsAttr =
5354 (inBounds && !inBounds.value().empty())
5355 ? builder.getBoolArrayAttr(inBounds.value())
5356 : builder.getBoolArrayAttr(SmallVector<bool>(
5357 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5358 build(builder, result, vector, dest, indices, permutationMapAttr,
5359 /*mask=*/Value(), inBoundsAttr);
5360}
5361
5362/// 4. Builder with type inference that sets an empty mask and sets permutation
5363/// map to 'getMinorIdentityMap'.
5364void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5365 Value vector, Value dest, ValueRange indices,
5366 std::optional<ArrayRef<bool>> inBounds) {
5367 auto vectorType = llvm::cast<VectorType>(vector.getType());
5368 AffineMap permutationMap = getTransferMinorIdentityMap(
5369 llvm::cast<ShapedType>(dest.getType()), vectorType);
5370 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5371}
5372
5373ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5374 OperationState &result) {
5375 auto &builder = parser.getBuilder();
5376 SMLoc typesLoc;
5377 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5378 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5379 SmallVector<Type, 2> types;
5380 OpAsmParser::UnresolvedOperand maskInfo;
5381 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5382 parser.parseOperand(sourceInfo) ||
5383 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5384 return failure();
5385 ParseResult hasMask = parser.parseOptionalComma();
5386 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5387 return failure();
5388 if (parser.parseOptionalAttrDict(result.attributes) ||
5389 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5390 return failure();
5391 if (types.size() != 2)
5392 return parser.emitError(typesLoc, "requires two types");
5393 auto indexType = builder.getIndexType();
5394 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5395 if (!vectorType)
5396 return parser.emitError(typesLoc, "requires vector type");
5397 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5398 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5399 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5400 auto permMapAttrName =
5401 TransferWriteOp::getPermutationMapAttrName(result.name);
5402 auto permMapAttr = result.attributes.get(permMapAttrName);
5403 AffineMap permMap;
5404 if (!permMapAttr) {
5405 if (shapedType.getRank() <
5406 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5407 return parser.emitError(typesLoc,
5408 "expected a custom permutation_map when "
5409 "rank(source) != rank(destination)");
5410 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5411 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5412 } else {
5413 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5414 }
5415 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5416 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5417 if (!inBoundsAttr) {
5418 result.addAttribute(inBoundsAttrName,
5419 builder.getBoolArrayAttr(
5420 SmallVector<bool>(permMap.getNumResults(), false)));
5421 }
5422 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5423 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5424 parser.resolveOperands(indexInfo, indexType, result.operands))
5425 return failure();
5426 if (hasMask.succeeded()) {
5427 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5428 return parser.emitError(
5429 maskInfo.location, "does not support masks with vector element type");
5430 if (vectorType.getRank() != permMap.getNumResults()) {
5431 return parser.emitError(typesLoc,
5432 "expected the same rank for the vector and the "
5433 "results of the permutation map");
5434 }
5435 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5436 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5437 return failure();
5438 }
5439 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5440 builder.getDenseI32ArrayAttr(
5441 {1, 1, static_cast<int32_t>(indexInfo.size()),
5442 static_cast<int32_t>(hasMask.succeeded())}));
5443 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5444 parser.addTypeToList(shapedType, result.types));
5445}
5446
5447void TransferWriteOp::print(OpAsmPrinter &p) {
5448 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5449 if (getMask())
5450 p << ", " << getMask();
5451 printTransferAttrs(p, *this);
5452 p << " : " << getVectorType() << ", " << getShapedType();
5453}
5454
5455LogicalResult TransferWriteOp::verify() {
5456 // Consistency of elemental types in shape and vector.
5457 ShapedType shapedType = getShapedType();
5458 VectorType vectorType = getVectorType();
5459 VectorType maskType = getMaskType();
5460 auto permutationMap = getPermutationMap();
5461 VectorType inferredMaskType =
5462 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5463 : VectorType();
5464
5465 if (llvm::size(getIndices()) != shapedType.getRank())
5466 return emitOpError("requires ") << shapedType.getRank() << " indices";
5467
5468 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5469 // as the semantics is unclear. This can be revisited later if necessary.
5470 if (hasBroadcastDim())
5471 return emitOpError("should not have broadcast dimensions");
5472
5473 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5474 shapedType, vectorType, maskType,
5475 inferredMaskType, permutationMap, getInBounds())))
5476 return failure();
5477
5478 return verifyPermutationMap(permutationMap,
5479 [&](Twine t) { return emitOpError(t); });
5480}
5481
5482//===----------------------------------------------------------------------===//
5483// TransferWriteOp: MaskableOpInterface methods.
5484//===----------------------------------------------------------------------===//
5485
5486/// Returns the mask type expected by this operation. Mostly used for
5487/// verification purposes.
5488Type TransferWriteOp::getExpectedMaskType() {
5489 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5490}
5491
5492//===----------------------------------------------------------------------===//
5493// TransferWriteOp: VectorTransferOpInterface methods.
5494//===----------------------------------------------------------------------===//
5495Value TransferWriteOp::getVector() { return getOperand(0); }
5496VectorType TransferWriteOp::getVectorType() {
5497 return cast<VectorType>(getValueToStore().getType());
5498}
5499
5500//===----------------------------------------------------------------------===//
5501// TransferWriteOp: fold methods.
5502//===----------------------------------------------------------------------===//
5503/// Fold:
5504/// ```
5505/// %t1 = ...
5506/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5507/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5508/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5509/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5510/// ```
5511///
5512/// into:
5513///
5514/// ```
5515/// %t0
5516/// ```
5517///
5518/// The producer of t1 may or may not be DCE'd depending on whether it is a
5519/// block argument or has side effects.
5520static LogicalResult foldReadInitWrite(TransferWriteOp write,
5521 ArrayRef<Attribute>,
5522 SmallVectorImpl<OpFoldResult> &results) {
5523 // TODO: support 0-d corner case.
5524 if (write.getTransferRank() == 0)
5525 return failure();
5526 auto rankedTensorType =
5527 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5528 // If not operating on tensors, bail.
5529 if (!rankedTensorType)
5530 return failure();
5531 // If no read, bail.
5532 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5533 if (!read)
5534 return failure();
5535 // TODO: support 0-d corner case.
5536 if (read.getTransferRank() == 0)
5537 return failure();
5538 // For now, only accept minor identity maps. Future: accept maps whose composition is a minor identity.
5539 if (!read.getPermutationMap().isMinorIdentity() ||
5540 !write.getPermutationMap().isMinorIdentity())
5541 return failure();
5542 // Bail on mismatching ranks.
5543 if (read.getTransferRank() != write.getTransferRank())
5544 return failure();
5545 // Bail on potential out-of-bounds accesses.
5546 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5547 return failure();
5548 // Tensor types must be the same.
5549 if (read.getBase().getType() != rankedTensorType)
5550 return failure();
5551 // Vector types must be the same.
5552 if (read.getVectorType() != write.getVectorType())
5553 return failure();
5554 // Vector and Tensor shapes must match.
5555 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5556 return failure();
5557 // If any index is nonzero.
5558 auto isNotConstantZero = [](Value v) {
5559 auto cstOp = getConstantIntValue(v);
5560 return !cstOp.has_value() || cstOp.value() != 0;
5561 };
5562 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5563 llvm::any_of(write.getIndices(), isNotConstantZero))
5564 return failure();
5565 // Success.
5566 results.push_back(read.getBase());
5567 return success();
5568}
5569
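/// Returns true when `write` stores back exactly the value that `read` loaded
/// from the same tensor (same indices, permutation map and vector type, and
/// neither access masked), i.e. the write leaves the tensor unchanged.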
5570static bool checkSameValueWAR(vector::TransferReadOp read,
5571 vector::TransferWriteOp write) {
5572 return read.getBase() == write.getBase() &&
5573 read.getIndices() == write.getIndices() &&
5574 read.getPermutationMap() == write.getPermutationMap() &&
5575 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5576 !write.getMask();
5577}
5578/// Fold transfer_write write after read:
5579/// ```
5580/// %t0 = ...
5581/// %v = vector.transfer_read %t0[%c0...] :
5582/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5583/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5584/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5585/// ```
5586///
5587/// into:
5588///
5589/// ```
5590/// %t0
5591/// ```
5592static LogicalResult foldWAR(TransferWriteOp write,
5593 SmallVectorImpl<OpFoldResult> &results) {
5594 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5595 return failure();
5596 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5597 if (!read)
5598 return failure();
5599
5600 if (!checkSameValueWAR(read, write))
5601 return failure();
5602 results.push_back(read.getBase());
5603 return success();
5604}
5605
5606LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5607 SmallVectorImpl<OpFoldResult> &results) {
5608 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5609 return success();
5610 if (succeeded(foldWAR(*this, results)))
5611 return success();
5612 if (succeeded(foldTransferInBoundsAttribute(*this)))
5613 return success();
5614 if (succeeded(foldTransferFullMask(*this)))
5615 return success();
5616 return memref::foldMemRefCast(*this);
5617}
5618
5619//===----------------------------------------------------------------------===//
5620// TransferWriteOp: other methods.
5621//===----------------------------------------------------------------------===//
5622std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5623 return llvm::to_vector<4>(getVectorType().getShape());
5624}
5625
5626void TransferWriteOp::getEffects(
5627 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5628 &effects) {
5629 if (llvm::isa<MemRefType>(getShapedType()))
5630 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5631 SideEffects::DefaultResource::get());
5632}
5633
5634Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5635 if (hasPureTensorSemantics())
5636 return Speculation::Speculatable;
5637 return Speculation::NotSpeculatable;
5638}
5639
5640namespace {
5641/// Remove a dead transfer write from the SSA chain so that it can be
5642/// eliminated by DCE.
5643/// ```
5644/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5645/// : vector<1x4xf32>, tensor<4x4xf32>
5646/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5647/// : vector<1x4xf32>, tensor<4x4xf32>
5648/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5649/// : vector<1x4xf32>, tensor<4x4xf32>
5650/// ```
5651///
5652/// into:
5653///
5654/// ```
5655/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5656/// : vector<1x4xf32>, tensor<4x4xf32>
5657/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5658/// : vector<1x4xf32>, tensor<4x4xf32>
5659/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5660/// : vector<1x4xf32>, tensor<4x4xf32>
5661/// ```
5662///
5663/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5664/// any other uses.
5665class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5666public:
5667 using Base::Base;
5668 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5669 PatternRewriter &rewriter) const override {
5670 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5671 return failure();
5672 vector::TransferWriteOp writeToModify = writeOp;
5673
5674 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5675 while (defWrite) {
5676 if (checkSameValueWAW(writeOp, defWrite)) {
5677 rewriter.modifyOpInPlace(writeToModify, [&]() {
5678 writeToModify.getBaseMutable().assign(defWrite.getBase());
5679 });
5680 return success();
5681 }
5682 if (!isDisjointTransferIndices(
5683 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5684 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5685 break;
5686 // If the previous write op doesn't have any other use we can safely look
5687 // at the previous store to see if it can be removed.
5688 if (!defWrite->hasOneUse())
5689 break;
5690 writeToModify = defWrite;
5691 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5692 }
5693 return failure();
5694 }
5695};
5696
5697/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5698/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5699/// overwritten and inserted into another tensor. After this rewrite, the
5700/// operations bufferize in-place since all of them work on the same slice.
5701///
5702/// For example:
5703/// ```mlir
5704/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5705/// : vector<8x16xf32>, tensor<8x16xf32>
5706/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5707/// : tensor<8x16xf32> to tensor<?x?xf32>
5708/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5709/// : tensor<?x?xf32> into tensor<27x37xf32>
5710/// ```
5711/// folds to
5712/// ```mlir
5713/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5714/// : tensor<27x37xf32> to tensor<?x?xf32>
5715/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5716/// : vector<8x16xf32>, tensor<?x?xf32>
5717/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5718/// : tensor<?x?xf32> into tensor<27x37xf32>
5719/// ```
5720struct SwapExtractSliceOfTransferWrite
5721 : public OpRewritePattern<tensor::InsertSliceOp> {
5722public:
5723 using Base::Base;
5724
5725 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5726 PatternRewriter &rewriter) const override {
5727 if (!insertOp.hasUnitStride())
5728 return failure();
5729 auto extractOp =
5730 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5731 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5732 return failure();
5733 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5734 if (!transferOp || !transferOp->hasOneUse())
5735 return failure();
5736
5737 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5738 // rank-reducing.
5739 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5740 return rewriter.notifyMatchFailure(insertOp,
5741 "use-def chain is rank-reducing");
5742 }
5743
5744 // Fail if tensor::ExtractSliceOp has non-zero offset.
5745 if (!extractOp.hasZeroOffset()) {
5746 return rewriter.notifyMatchFailure(insertOp,
5747 "ExtractSliceOp has non-zero offset");
5748 }
5749
5750 // Fail if vector::TransferWriteOp has non-zero offset.
5751 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5752 return getConstantIntValue(value) == static_cast<int64_t>(0);
5753 })) {
5754 return rewriter.notifyMatchFailure(insertOp,
5755 "TranferWriteOp has non-zero offset");
5756 }
5757
5758 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5759 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5760 return rewriter.notifyMatchFailure(
5761 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5762 }
5763
5764 for (auto [insertSize, extractSize] :
5765 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5766 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5767 return rewriter.notifyMatchFailure(
5768 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5769 }
5770 }
5771
5772 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5773 assert(transferOp.getVectorType().hasStaticShape() &&
5774 "expected vector to have a static shape");
5775 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5776 SmallVector<int64_t> resultShape = applyPermutationMap(
5777 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5778 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5779 return rewriter.notifyMatchFailure(
5780 insertOp, "TransferWriteOp may not write the full tensor.");
5781 }
5782
5783 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5784 // Set all in_bounds to false and let the folder infer them.
5785 SmallVector<bool> newInBounds(vectorShape.size(), false);
5786 auto newExtractOp = tensor::ExtractSliceOp::create(
5787 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5788 insertOp.getDest(), insertOp.getMixedOffsets(),
5789 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5790 auto newTransferWriteOp = TransferWriteOp::create(
5791 rewriter, transferOp.getLoc(), transferOp.getVector(),
5792 newExtractOp.getResult(), transferOp.getIndices(),
5793 transferOp.getPermutationMapAttr(),
5794 rewriter.getBoolArrayAttr(newInBounds));
5795 rewriter.modifyOpInPlace(insertOp, [&]() {
5796 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5797 });
5798 return success();
5799 }
5800};
5801
5802} // namespace
5803
5804void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5805 MLIRContext *context) {
5806 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5807}
5808
5809FailureOr<std::optional<SmallVector<Value>>>
5810TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5811 if (!hasPureBufferSemantics())
5812 return failure();
5814 ValueRange());
5815}
5816
5817//===----------------------------------------------------------------------===//
5818// LoadOp
5819//===----------------------------------------------------------------------===//
5820
5821static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5822 VectorType vecTy,
5823 MemRefType memRefTy) {
5824 // If rank==0 or size==1 it's equivalent to a scalar load/store, so we don't
5825 // need any stride limitations.
5826 if (!vecTy.isScalable() &&
5827 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5828 return success();
5829
5830 if (!memRefTy.isLastDimUnitStride())
5831 return op->emitOpError("most minor memref dim must have unit stride");
5832 return success();
5833}
5834
5835LogicalResult vector::LoadOp::verify() {
5836 VectorType resVecTy = getVectorType();
5837 MemRefType memRefTy = getMemRefType();
5838
5839 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5840 return failure();
5841
5842 if (memRefTy.getRank() < resVecTy.getRank())
5843 return emitOpError(
5844 "destination memref has lower rank than the result vector");
5845
5846 // Checks for vector memrefs.
5847 Type memElemTy = memRefTy.getElementType();
5848 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5849 if (memVecTy != resVecTy)
5850 return emitOpError("base memref and result vector types should match");
5851 memElemTy = memVecTy.getElementType();
5852 }
5853
5854 if (resVecTy.getElementType() != memElemTy)
5855 return emitOpError("base and result element types should match");
5856 if (llvm::size(getIndices()) != memRefTy.getRank())
5857 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5858 return success();
5859}
5860
5861OpFoldResult LoadOp::fold(FoldAdaptor) {
5862 if (succeeded(memref::foldMemRefCast(*this)))
5863 return getResult();
5864 return OpFoldResult();
5865}
5866
5867std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5868 return llvm::to_vector<4>(getVectorType().getShape());
5869}
5870
5871FailureOr<std::optional<SmallVector<Value>>>
5872LoadOp::bubbleDownCasts(OpBuilder &builder) {
5874 getResult());
5875}
5876
5877//===----------------------------------------------------------------------===//
5878// StoreOp
5879//===----------------------------------------------------------------------===//
5880
5881LogicalResult vector::StoreOp::verify() {
5882 VectorType valueVecTy = getVectorType();
5883 MemRefType memRefTy = getMemRefType();
5884
5885 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5886 return failure();
5887
5888 if (memRefTy.getRank() < valueVecTy.getRank())
5889 return emitOpError("source memref has lower rank than the vector to store");
5890
5891 // Checks for vector memrefs.
5892 Type memElemTy = memRefTy.getElementType();
5893 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5894 if (memVecTy != valueVecTy)
5895 return emitOpError(
5896 "base memref and valueToStore vector types should match");
5897 memElemTy = memVecTy.getElementType();
5898 }
5899
5900 if (valueVecTy.getElementType() != memElemTy)
5901 return emitOpError("base and valueToStore element type should match");
5902 if (llvm::size(getIndices()) != memRefTy.getRank())
5903 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5904 return success();
5905}
5906
5907LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5908 SmallVectorImpl<OpFoldResult> &results) {
5909 return memref::foldMemRefCast(*this);
5910}
5911
5912std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5913 return llvm::to_vector<4>(getVectorType().getShape());
5914}
5915
5916FailureOr<std::optional<SmallVector<Value>>>
5917StoreOp::bubbleDownCasts(OpBuilder &builder) {
5919 ValueRange());
5920}
5921
5922//===----------------------------------------------------------------------===//
5923// MaskedLoadOp
5924//===----------------------------------------------------------------------===//
5925
5926LogicalResult MaskedLoadOp::verify() {
5927 VectorType maskVType = getMaskVectorType();
5928 VectorType passVType = getPassThruVectorType();
5929 VectorType resVType = getVectorType();
5930 MemRefType memType = getMemRefType();
5931
5932 if (resVType.getElementType() != memType.getElementType())
5933 return emitOpError("base and result element type should match");
5934 if (llvm::size(getIndices()) != memType.getRank())
5935 return emitOpError("requires ") << memType.getRank() << " indices";
5936 if (resVType.getShape() != maskVType.getShape())
5937 return emitOpError("expected result shape to match mask shape");
5938 if (resVType != passVType)
5939 return emitOpError("expected pass_thru of same type as result type");
5940 return success();
5941}
5942
5943namespace {
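/// Folds vector.maskedload with a statically known mask: an all-true mask
/// becomes an unmasked vector.load, an all-false mask is replaced by the
/// pass-thru value.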
5944class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5945public:
5946 using Base::Base;
5947 LogicalResult matchAndRewrite(MaskedLoadOp load,
5948 PatternRewriter &rewriter) const override {
5949 switch (getMaskFormat(load.getMask())) {
5950 case MaskFormat::AllTrue:
5951 rewriter.replaceOpWithNewOp<vector::LoadOp>(
5952 load, load.getType(), load.getBase(), load.getIndices());
5953 return success();
5954 case MaskFormat::AllFalse:
5955 rewriter.replaceOp(load, load.getPassThru());
5956 return success();
5957 case MaskFormat::Unknown:
5958 return failure();
5959 }
5960 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5961 }
5962};
5963} // namespace
5964
5965void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5966 MLIRContext *context) {
5967 results.add<MaskedLoadFolder>(context);
5968}
5969
5970OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5971 if (succeeded(memref::foldMemRefCast(*this)))
5972 return getResult();
5973 return OpFoldResult();
5974}
5975
5976FailureOr<std::optional<SmallVector<Value>>>
5977MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
5979 getResult());
5980}
5981
5982//===----------------------------------------------------------------------===//
5983// MaskedStoreOp
5984//===----------------------------------------------------------------------===//
5985
5986LogicalResult MaskedStoreOp::verify() {
5987 VectorType maskVType = getMaskVectorType();
5988 VectorType valueVType = getVectorType();
5989 MemRefType memType = getMemRefType();
5990
5991 if (valueVType.getElementType() != memType.getElementType())
5992 return emitOpError("base and valueToStore element type should match");
5993 if (llvm::size(getIndices()) != memType.getRank())
5994 return emitOpError("requires ") << memType.getRank() << " indices";
5995 if (valueVType.getShape() != maskVType.getShape())
5996 return emitOpError("expected valueToStore shape to match mask shape");
5997 return success();
5998}
5999
6000namespace {
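/// Folds vector.maskedstore with a statically known mask: an all-true mask
/// becomes an unmasked vector.store, an all-false mask erases the store.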
6001class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
6002public:
6003 using Base::Base;
6004 LogicalResult matchAndRewrite(MaskedStoreOp store,
6005 PatternRewriter &rewriter) const override {
6006 switch (getMaskFormat(store.getMask())) {
6007 case MaskFormat::AllTrue:
6008 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6009 store, store.getValueToStore(), store.getBase(), store.getIndices());
6010 return success();
6011 case MaskFormat::AllFalse:
6012 rewriter.eraseOp(store);
6013 return success();
6014 case MaskFormat::Unknown:
6015 return failure();
6016 }
6017 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
6018 }
6019};
6020} // namespace
6021
6022void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6023 MLIRContext *context) {
6024 results.add<MaskedStoreFolder>(context);
6025}
6026
6027LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6028 SmallVectorImpl<OpFoldResult> &results) {
6029 return memref::foldMemRefCast(*this);
6030}
6031
6032FailureOr<std::optional<SmallVector<Value>>>
6033MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6035 ValueRange());
6036}
6037
6038//===----------------------------------------------------------------------===//
6039// GatherOp
6040//===----------------------------------------------------------------------===//
6041
6042LogicalResult GatherOp::verify() {
6043 VectorType indVType = getIndexVectorType();
6044 VectorType maskVType = getMaskVectorType();
6045 VectorType resVType = getVectorType();
6046 ShapedType baseType = getBaseType();
6047
6048 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6049 return emitOpError("requires base to be a memref or ranked tensor type");
6050
6051 if (resVType.getElementType() != baseType.getElementType())
6052 return emitOpError("base and result element type should match");
6053 if (llvm::size(getOffsets()) != baseType.getRank())
6054 return emitOpError("requires ") << baseType.getRank() << " indices";
6055 if (resVType.getShape() != indVType.getShape())
6056 return emitOpError("expected result dim to match indices dim");
6057 if (resVType.getShape() != maskVType.getShape())
6058 return emitOpError("expected result dim to match mask dim");
6059 if (resVType != getPassThruVectorType())
6060 return emitOpError("expected pass_thru of same type as result type");
6061 return success();
6062}
6063
6064// MaskableOpInterface methods.
6065
6066/// Returns the mask type expected by this operation. Mostly used for
6067 /// verification purposes. It requires the operation to be vectorized.
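/// E.g. an index vector of type vector<4x[8]xindex> yields the mask type
/// vector<4x[8]xi1>.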
6068Type GatherOp::getExpectedMaskType() {
6069 auto vecType = this->getIndexVectorType();
6070 return VectorType::get(vecType.getShape(),
6071 IntegerType::get(vecType.getContext(), /*width=*/1),
6072 vecType.getScalableDims());
6073}
6074
6075std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6076 return llvm::to_vector<4>(getVectorType().getShape());
6077}
6078
6079 /// Check if `indexVec` is a constant 1D vector of consecutive values [0, 1, 2, ...].
6080static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6081 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6082 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6083 return failure();
6084
6085 if (indexVec.getDefiningOp<StepOp>())
6086 return success();
6087
6088 DenseIntElementsAttr elements;
6089 if (!matchPattern(indexVec, m_Constant(&elements)))
6090 return failure();
6091
6092 return success(
6093 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6094}
6095
6096namespace {
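/// Folds vector.gather with a statically all-false mask to its pass-thru
/// value. There is no unmasked equivalent, so an all-true mask is left as-is.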
6097class GatherFolder final : public OpRewritePattern<GatherOp> {
6098public:
6099 using Base::Base;
6100 LogicalResult matchAndRewrite(GatherOp gather,
6101 PatternRewriter &rewriter) const override {
6102 switch (getMaskFormat(gather.getMask())) {
6103 case MaskFormat::AllTrue:
6104 return failure(); // no unmasked equivalent
6105 case MaskFormat::AllFalse:
6106 rewriter.replaceOp(gather, gather.getPassThru());
6107 return success();
6108 case MaskFormat::Unknown:
6109 return failure();
6110 }
6111 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6112 }
6113};
6114
6115/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6116/// maskedload. Only 1D fixed vectors are supported for now.
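/// A sketch of the rewrite (SSA names are illustrative):
///   %i = vector.step : vector<16xindex>
///   %g = vector.gather %base[%c0][%i], %mask, %pass
///       : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
/// becomes
///   %g = vector.maskedload %base[%c0], %mask, %pass
///       : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>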
6117class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6118public:
6119 using Base::Base;
6120 LogicalResult matchAndRewrite(GatherOp op,
6121 PatternRewriter &rewriter) const override {
6122 if (!isa<MemRefType>(op.getBase().getType()))
6123 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6124
6125 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6126 return failure();
6127
6128 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6129 op.getOffsets(), op.getMask(),
6130 op.getPassThru());
6131 return success();
6132 }
6133};
6134} // namespace
6135
6136void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6137 MLIRContext *context) {
6138 results.add<GatherFolder, FoldContiguousGather>(context);
6139}
6140
6141FailureOr<std::optional<SmallVector<Value>>>
6142GatherOp::bubbleDownCasts(OpBuilder &builder) {
6144 getResult());
6145}
6146
6147//===----------------------------------------------------------------------===//
6148// ScatterOp
6149//===----------------------------------------------------------------------===//
6150
6151LogicalResult ScatterOp::verify() {
6152 VectorType indVType = getIndexVectorType();
6153 VectorType maskVType = getMaskVectorType();
6154 VectorType valueVType = getVectorType();
6155 ShapedType baseType = getBaseType();
6156
6157 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6158 return emitOpError("requires base to be a memref or ranked tensor type");
6159
6160 if (valueVType.getElementType() != baseType.getElementType())
6161 return emitOpError("base and valueToStore element type should match");
6162 if (llvm::size(getOffsets()) != baseType.getRank())
6163 return emitOpError("requires ") << baseType.getRank() << " indices";
6164 if (valueVType.getShape() != indVType.getShape())
6165 return emitOpError("expected valueToStore dim to match indices dim");
6166 if (valueVType.getShape() != maskVType.getShape())
6167 return emitOpError("expected valueToStore dim to match mask dim");
6168 return success();
6169}
6170namespace {
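/// Folds vector.scatter with a statically all-false mask: for a memref base
/// the op is erased, for a tensor base it is replaced by the unchanged base.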
6171class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6172public:
6173 using Base::Base;
6174 LogicalResult matchAndRewrite(ScatterOp scatter,
6175 PatternRewriter &rewriter) const override {
6176 ShapedType baseType = scatter.getBaseType();
6177 bool isMemRef = isa<MemRefType>(baseType);
6178 if (!isMemRef && !isa<RankedTensorType>(baseType))
6179 return failure();
6180
6181 // Memrefs have no result, so an all-false mask can simply erase the op.
6182 // Tensors carry the updated value, so we must replace uses with the
6183 // original base tensor instead of erasing.
6184 switch (getMaskFormat(scatter.getMask())) {
6185 case MaskFormat::AllTrue:
6186 return failure(); // no unmasked equivalent
6187 case MaskFormat::AllFalse:
6188 if (isMemRef)
6189 rewriter.eraseOp(scatter);
6190 else
6191 rewriter.replaceOp(scatter, scatter.getBase());
6192 return success();
6193 case MaskFormat::Unknown:
6194 return failure();
6195 }
6196 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6197 }
6198};
6199
6200/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6201/// maskedstore. Only 1D fixed vectors are supported for now.
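/// A sketch of the rewrite (SSA names are illustrative):
///   %i = vector.step : vector<16xindex>
///   vector.scatter %base[%c0][%i], %mask, %value
///       : memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
/// becomes
///   vector.maskedstore %base[%c0], %mask, %value
///       : memref<?xf32>, vector<16xi1>, vector<16xf32>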
6202class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6203public:
6204 using Base::Base;
6205 LogicalResult matchAndRewrite(ScatterOp op,
6206 PatternRewriter &rewriter) const override {
6207 // Fold only for memrefs: the replacement uses maskedstore, which does not
6208 // support tensor bases. Tensor cases intentionally bail out.
6209 if (!isa<MemRefType>(op.getBase().getType()))
6210 return failure();
6211
6212 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6213 return failure();
6214
6215 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6216 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6217 return success();
6218 }
6219};
6220} // namespace
6221
6222void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6223 MLIRContext *context) {
6224 results.add<ScatterFolder, FoldContiguousScatter>(context);
6225}
6226
6227FailureOr<std::optional<SmallVector<Value>>>
6228ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6230 ValueRange());
6231}
6232
6233//===----------------------------------------------------------------------===//
6234// ExpandLoadOp
6235//===----------------------------------------------------------------------===//
6236
6237LogicalResult ExpandLoadOp::verify() {
6238 VectorType maskVType = getMaskVectorType();
6239 VectorType passVType = getPassThruVectorType();
6240 VectorType resVType = getVectorType();
6241 MemRefType memType = getMemRefType();
6242
6243 if (resVType.getElementType() != memType.getElementType())
6244 return emitOpError("base and result element type should match");
6245 if (llvm::size(getIndices()) != memType.getRank())
6246 return emitOpError("requires ") << memType.getRank() << " indices";
6247 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6248 return emitOpError("expected result dim to match mask dim");
6249 if (resVType != passVType)
6250 return emitOpError("expected pass_thru of same type as result type");
6251 return success();
6252}
6253
6254namespace {
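/// Folds vector.expandload with a statically known mask: an all-true mask
/// reads contiguously and becomes an unmasked vector.load, an all-false mask
/// is replaced by the pass-thru value.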
6255class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6256public:
6257 using Base::Base;
6258 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6259 PatternRewriter &rewriter) const override {
6260 switch (getMaskFormat(expand.getMask())) {
6261 case MaskFormat::AllTrue:
6262 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6263 expand, expand.getType(), expand.getBase(), expand.getIndices());
6264 return success();
6265 case MaskFormat::AllFalse:
6266 rewriter.replaceOp(expand, expand.getPassThru());
6267 return success();
6268 case MaskFormat::Unknown:
6269 return failure();
6270 }
6271 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6272 }
6273};
6274} // namespace
6275
6276void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6277 MLIRContext *context) {
6278 results.add<ExpandLoadFolder>(context);
6279}
6280
6281FailureOr<std::optional<SmallVector<Value>>>
6282ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6284 getResult());
6285}
6286
6287//===----------------------------------------------------------------------===//
6288// CompressStoreOp
6289//===----------------------------------------------------------------------===//
6290
6291LogicalResult CompressStoreOp::verify() {
6292 VectorType maskVType = getMaskVectorType();
6293 VectorType valueVType = getVectorType();
6294 MemRefType memType = getMemRefType();
6295
6296 if (valueVType.getElementType() != memType.getElementType())
6297 return emitOpError("base and valueToStore element type should match");
6298 if (llvm::size(getIndices()) != memType.getRank())
6299 return emitOpError("requires ") << memType.getRank() << " indices";
6300 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6301 return emitOpError("expected valueToStore dim to match mask dim");
6302 return success();
6303}
6304
6305namespace {
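/// Folds vector.compressstore with a statically known mask: an all-true mask
/// writes contiguously and becomes an unmasked vector.store, an all-false
/// mask erases the store.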
6306class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6307public:
6308 using Base::Base;
6309 LogicalResult matchAndRewrite(CompressStoreOp compress,
6310 PatternRewriter &rewriter) const override {
6311 switch (getMaskFormat(compress.getMask())) {
6312 case MaskFormat::AllTrue:
6313 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6314 compress, compress.getValueToStore(), compress.getBase(),
6315 compress.getIndices());
6316 return success();
6317 case MaskFormat::AllFalse:
6318 rewriter.eraseOp(compress);
6319 return success();
6320 case MaskFormat::Unknown:
6321 return failure();
6322 }
6323 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6324 }
6325};
6326} // namespace
6327
6328void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6329 MLIRContext *context) {
6330 results.add<CompressStoreFolder>(context);
6331}
6332
6333FailureOr<std::optional<SmallVector<Value>>>
6334CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6336 ValueRange());
6337}
6338
6339//===----------------------------------------------------------------------===//
6340// ShapeCastOp
6341//===----------------------------------------------------------------------===//
6342
6343void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6344 SetIntRangeFn setResultRanges) {
6345 setResultRanges(getResult(), argRanges.front());
6346}
6347
6348std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6349 return llvm::to_vector<4>(getResultVectorType().getShape());
6350}
6351
6352LogicalResult ShapeCastOp::verify() {
6353
6354 VectorType sourceType = getSourceVectorType();
6355 VectorType resultType = getResultVectorType();
6356
6357 // Check that element type is preserved
6358 if (sourceType.getElementType() != resultType.getElementType())
6359 return emitOpError("has different source and result element types");
6360
6361 // Check that number of elements is preserved
6362 int64_t sourceNElms = sourceType.getNumElements();
6363 int64_t resultNElms = resultType.getNumElements();
6364 if (sourceNElms != resultNElms) {
6365 return emitOpError() << "has different number of elements at source ("
6366 << sourceNElms << ") and result (" << resultNElms
6367 << ")";
6368 }
6369
6370 // Check that (non-)scalability is preserved
6371 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6372 int64_t resultNScalableDims = resultType.getNumScalableDims();
6373 if (sourceNScalableDims != resultNScalableDims)
6374 return emitOpError() << "has different number of scalable dims at source ("
6375 << sourceNScalableDims << ") and result ("
6376 << resultNScalableDims << ")";
6377
6378 return success();
6379}
6380
6381/// Return true if `transpose` does not permute a pair of non-unit dims.
6382/// By `order preserving` we mean that the flattened versions of the input and
6383/// output vectors are (numerically) identical. In other words `transpose` is
6384/// effectively a shape cast.
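/// For example, permutation [0, 2, 1] on vector<2x1x4xf32> is order
/// preserving (only the unit dim moves), whereas [1, 0] on vector<2x3xf32>
/// is not.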
6385static bool isOrderPreserving(TransposeOp transpose) {
6386 ArrayRef<int64_t> permutation = transpose.getPermutation();
6387 VectorType sourceType = transpose.getSourceVectorType();
6388 ArrayRef<int64_t> inShape = sourceType.getShape();
6389 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6390 auto isNonScalableUnitDim = [&](int64_t dim) {
6391 return inShape[dim] == 1 && !inDimIsScalable[dim];
6392 };
6393 int64_t current = 0;
6394 for (auto p : permutation) {
6395 if (!isNonScalableUnitDim(p)) {
6396 if (p < current) {
6397 return false;
6398 }
6399 current = p;
6400 }
6401 }
6402 return true;
6403}
6404
6405OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6406
6407 VectorType resultType = getType();
6408
6409 // No-op shape cast.
6410 if (getSource().getType() == resultType)
6411 return getSource();
6412
6413 // shape_cast(shape_cast(x)) -> shape_cast(x)
6414 if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6415 setOperand(precedingShapeCast.getSource());
6416 return getResult();
6417 }
6418
6419 // shape_cast(transpose(x)) -> shape_cast(x)
6420 if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6421 if (isOrderPreserving(transpose)) {
6422 setOperand(transpose.getVector());
6423 return getResult();
6424 }
6425 return {};
6426 }
6427
6428 // Y = shape_cast(broadcast(X))
6429 // -> X, if X and Y have same type
6430 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6431 if (bcastOp.getSourceType() == resultType)
6432 return bcastOp.getSource();
6433 }
6434
6435 // shape_cast(constant) -> constant
6436 if (auto denseAttr =
6437 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6438 return denseAttr.reshape(getType());
6439
6440 // shape_cast(poison) -> poison
6441 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
6442 return ub::PoisonAttr::get(getContext());
6443
6444 return {};
6445}
6446
6447namespace {
6448
6449/// Helper function that computes a new vector type based on the input vector
6450/// type by removing the trailing one dims:
6451///
6452/// vector<4x1x1xi1> --> vector<4x1xi1>
6453///
6454static VectorType trimTrailingOneDims(VectorType oldType) {
6455 ArrayRef<int64_t> oldShape = oldType.getShape();
6456 ArrayRef<int64_t> newShape = oldShape;
6457
6458 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6459 ArrayRef<bool> newScalableDims = oldScalableDims;
6460
6461 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6462 newShape = newShape.drop_back(1);
6463 newScalableDims = newScalableDims.drop_back(1);
6464 }
6465
6466 // Make sure we have at least 1 dimension.
6467 // TODO: Add support for 0-D vectors.
6468 if (newShape.empty()) {
6469 newShape = oldShape.take_back();
6470 newScalableDims = oldScalableDims.take_back();
6471 }
6472
6473 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6474}
6475
6476/// Folds qualifying shape_cast(create_mask) into a new create_mask
6477///
6478/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6479/// dimension. If the input vector comes from `vector.create_mask` for which
6480/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6481/// to fold shape_cast into create_mask.
6482///
6483/// BEFORE:
6484/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6485/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6486/// AFTER:
6487/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6488class ShapeCastCreateMaskFolderTrailingOneDim final
6489 : public OpRewritePattern<ShapeCastOp> {
6490public:
6491 using Base::Base;
6492
6493 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6494 PatternRewriter &rewriter) const override {
6495 Value shapeOpSrc = shapeOp->getOperand(0);
6496 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6497 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6498 if (!createMaskOp && !constantMaskOp)
6499 return failure();
6500
6501 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6502 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6503
6504 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6505 if (newVecType != shapeOpResTy)
6506 return failure();
6507
6508 auto numDimsToDrop =
6509 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6510
6511 // No unit dims to drop
6512 if (!numDimsToDrop)
6513 return failure();
6514
6515 if (createMaskOp) {
6516 auto maskOperands = createMaskOp.getOperands();
6517 auto numMaskOperands = maskOperands.size();
6518
6519 // Check every mask dim size to see whether it can be dropped
6520 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6521 --i) {
6522 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6523 if (!constant || (constant.value() != 1))
6524 return failure();
6525 }
6526 SmallVector<Value> newMaskOperands =
6527 maskOperands.drop_back(numDimsToDrop);
6528
6529 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6530 newMaskOperands);
6531 return success();
6532 }
6533
6534 if (constantMaskOp) {
6535 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6536 auto numMaskOperands = maskDimSizes.size();
6537
6538 // Check every mask dim size to see whether it can be dropped
6539 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6540 --i) {
6541 if (maskDimSizes[i] != 1)
6542 return failure();
6543 }
6544
6545 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6546 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6547 newMaskOperands);
6548 return success();
6549 }
6550
6551 return failure();
6552 }
6553};
6554
6555/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
6556/// i) Y = ShapeCast(X), or
6557/// ii) Y = Broadcast(X)
6558/// If both (i) and (ii) are possible, (i) is chosen.
6559class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6560public:
6561 using Base::Base;
6562
6563 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6564 PatternRewriter &rewriter) const override {
6565 auto broadcastOp =
6566 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6567 if (!broadcastOp)
6568 return failure();
6569
6570 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6571 bool srcIsScalar = !srcVectorType;
6572
6573 // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6574 // Example:
6575 // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6576 // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6577 // to
6578 // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6579 if (srcVectorType) {
6580 if (srcVectorType.getNumElements() ==
6581 shapeCastOp.getResultVectorType().getNumElements()) {
6582 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6583 shapeCastOp, shapeCastOp.getResultVectorType(),
6584 broadcastOp.getSource());
6585 return success();
6586 }
6587 }
6588
6589 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6590 // Example
6591 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6592 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6593 // to
6594 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6595 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6596 if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6597 BroadcastableToResult::Success) {
6598 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6599 shapeCastOp, dstVectorType, broadcastOp.getSource());
6600 return success();
6601 }
6602 return failure();
6603 }
6604};
6605
6606} // namespace
6607
6608void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6609 MLIRContext *context) {
6610 results
6611 .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6612 context);
6613}
6614
6615//===----------------------------------------------------------------------===//
6616// VectorBitCastOp
6617//===----------------------------------------------------------------------===//
6618
6619LogicalResult BitCastOp::verify() {
6620 auto sourceVectorType = getSourceVectorType();
6621 auto resultVectorType = getResultVectorType();
6622
6623 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6624 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6625 return emitOpError("dimension size mismatch at: ") << i;
6626 }
6627
6628 DataLayout dataLayout = DataLayout::closest(*this);
6629 auto sourceElementBits =
6630 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6631 auto resultElementBits =
6632 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6633
6634 if (sourceVectorType.getRank() == 0) {
6635 if (sourceElementBits != resultElementBits)
6636 return emitOpError("source/result bitwidth of the 0-D vector element "
6637 "types must be equal");
6638 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6639 resultElementBits * resultVectorType.getShape().back()) {
6640 return emitOpError(
6641 "source/result bitwidth of the minor 1-D vectors must be equal");
6642 }
6643
6644 return success();
6645}
6646
6647OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6648 // Nop cast.
6649 if (getSource().getType() == getResult().getType())
6650 return getSource();
6651
6652 // Canceling bitcasts.
6653 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6654 if (getResult().getType() == otherOp.getSource().getType())
6655 return otherOp.getSource();
6656
6657 setOperand(otherOp.getSource());
6658 return getResult();
6659 }
6660
6661 Attribute sourceConstant = adaptor.getSource();
6662 if (!sourceConstant)
6663 return {};
6664
6665 Type srcElemType = getSourceVectorType().getElementType();
6666 Type dstElemType = getResultVectorType().getElementType();
6667
6668 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6669 if (floatPack.isSplat()) {
6670 auto splat = floatPack.getSplatValue<FloatAttr>();
6671
6672 // Casting fp16 into fp32.
6673 if (srcElemType.isF16() && dstElemType.isF32()) {
6674 uint32_t bits = static_cast<uint32_t>(
6675 splat.getValue().bitcastToAPInt().getZExtValue());
6676 // Duplicate the 16-bit pattern.
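// E.g. an f16 splat of 1.0 (0x3C00) produces the f32 bit pattern 0x3C003C00.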
6677 bits = (bits << 16) | (bits & 0xffff);
6678 APInt intBits(32, bits);
6679 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6680 return DenseElementsAttr::get(getResultVectorType(), floatBits);
6681 }
6682 }
6683 }
6684
6685 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6686 if (intPack.isSplat()) {
6687 auto splat = intPack.getSplatValue<IntegerAttr>();
6688
6689 if (llvm::isa<IntegerType>(dstElemType)) {
6690 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
6691 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
6692
6693 // Casting to a larger integer bit width.
6694 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6695 APInt intBits = splat.getValue().zext(dstBitWidth);
6696
6697 // Duplicate the lower width element.
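// E.g. an i8 splat of 0xAB bitcast to i32 becomes the splat 0xABABABAB.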
6698 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6699 intBits = (intBits << srcBitWidth) | intBits;
6700 return DenseElementsAttr::get(getResultVectorType(), intBits);
6701 }
6702 }
6703 }
6704 }
6705
6706 return {};
6707}
6708
6709//===----------------------------------------------------------------------===//
6710// TypeCastOp
6711//===----------------------------------------------------------------------===//
6712
6713static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6714 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6715 SmallVector<int64_t, 8> res(memRefType.getShape());
6716 if (vectorType)
6717 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6718 return res;
6719}
6720
6721/// Build the canonical memRefType with a single vector.
6722/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6723void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6724 Value source) {
6725 result.addOperands(source);
6726 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6727 VectorType vectorType =
6728 VectorType::get(extractShape(memRefType),
6729 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
6730 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6731 memRefType.getMemorySpace()));
6732}
6733
6734LogicalResult TypeCastOp::verify() {
6735 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6736 if (!canonicalType.getLayout().isIdentity())
6737 return emitOpError("expects operand to be a memref with identity layout");
6738 if (!getResultMemRefType().getLayout().isIdentity())
6739 return emitOpError("expects result to be a memref with identity layout");
6740 if (getResultMemRefType().getMemorySpace() !=
6741 getMemRefType().getMemorySpace())
6742 return emitOpError("expects result in same memory space");
6743
6744 auto sourceType = getMemRefType();
6745 auto resultType = getResultMemRefType();
6746 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
6747 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
6748 return emitOpError(
6749 "expects result and operand with same underlying scalar type: ")
6750 << resultType;
6751 if (extractShape(sourceType) != extractShape(resultType))
6752 return emitOpError(
6753 "expects concatenated result and operand shapes to be equal: ")
6754 << resultType;
6755 return success();
6756}
6757
6758//===----------------------------------------------------------------------===//
6759// TransposeOp
6760//===----------------------------------------------------------------------===//
6761
6762void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6763 Value vector, ArrayRef<int64_t> permutation) {
6764 VectorType vt = llvm::cast<VectorType>(vector.getType());
6765 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6766 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6767 for (unsigned i = 0; i < permutation.size(); ++i) {
6768 transposedShape[i] = vt.getShape()[permutation[i]];
6769 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6770 }
6771
6772 result.addOperands(vector);
6773 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6774 transposedScalableDims));
6775 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6776 builder.getDenseI64ArrayAttr(permutation));
6777}
6778
6779OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6780 // Eliminate splat constant transpose ops.
6781 if (auto splat =
6782 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6783 return splat.reshape(getResultVectorType());
6784
6785 // Eliminate poison transpose ops.
6786 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6787 return ub::PoisonAttr::get(getContext());
6788
6789 // Eliminate identity transposes, and more generally any transposes that
6790 // preserves the shape without permuting elements.
6791 //
6792 // Examples of what to fold:
6793 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6794 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6795 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6796 //
6797 // Example of what NOT to fold:
6798 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6799 //
6800 if (getSourceVectorType() == getResultVectorType() &&
6801 isOrderPreserving(*this))
6802 return getVector();
6803
6804 return {};
6805}
6806
6807LogicalResult vector::TransposeOp::verify() {
6808 VectorType vectorType = getSourceVectorType();
6809 VectorType resultType = getResultVectorType();
6810 int64_t rank = resultType.getRank();
6811 if (vectorType.getRank() != rank)
6812 return emitOpError("vector result rank mismatch: ") << rank;
6813 // Verify transposition array.
6814 ArrayRef<int64_t> perm = getPermutation();
6815 int64_t size = perm.size();
6816 if (rank != size)
6817 return emitOpError("transposition length mismatch: ") << size;
6818 SmallVector<bool, 8> seen(rank, false);
6819 for (const auto &ta : llvm::enumerate(perm)) {
6820 if (ta.value() < 0 || ta.value() >= rank)
6821 return emitOpError("transposition index out of range: ") << ta.value();
6822 if (seen[ta.value()])
6823 return emitOpError("duplicate position index: ") << ta.value();
6824 seen[ta.value()] = true;
6825 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6826 return emitOpError("dimension size mismatch at: ") << ta.value();
6827 }
6828 return success();
6829}
6830
6831std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6832 return llvm::to_vector<4>(getResultVectorType().getShape());
6833}
6834
6835void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6836 SetIntRangeFn setResultRanges) {
6837 setResultRanges(getResult(), argRanges.front());
6838}
6839
6840namespace {
6841
6842// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
6843class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6844public:
6845 using Base::Base;
6846
6847 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6848 PatternRewriter &rewriter) const override {
6849 // Composes two permutations: result[i] = permutation1[permutation2[i]].
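// E.g. a parent permutation [1, 0] composed with an outer [1, 0] gives the
// identity permutation [0, 1].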
6850 auto composePermutations = [](ArrayRef<int64_t> permutation1,
6851 ArrayRef<int64_t> permutation2) {
6852 SmallVector<int64_t, 4> result;
6853 for (auto index : permutation2)
6854 result.push_back(permutation1[index]);
6855 return result;
6856 };
6857
6858 // Return failure if the input of 'transposeOp' is not defined by another transpose.
6859 vector::TransposeOp parentTransposeOp =
6860 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6861 if (!parentTransposeOp)
6862 return failure();
6863
6864 SmallVector<int64_t, 4> permutation = composePermutations(
6865 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6866 // Replace 'transposeOp' with a new transpose operation.
6867 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6868 transposeOp, transposeOp.getResult().getType(),
6869 parentTransposeOp.getVector(), permutation);
6870 return success();
6871 }
6872};
6873
6874/// Replace transpose(splat-like(v)) with broadcast(v)
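/// E.g. transpose([1, 0]) of broadcast(%s : f32 to vector<4x2xf32>) becomes
/// broadcast(%s : f32 to vector<2x4xf32>).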
6875class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6876public:
6877 using Base::Base;
6878
6879 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6880 PatternRewriter &rewriter) const override {
6881 Value splat = getScalarSplatSource(transposeOp.getVector());
6882 if (!splat)
6883 return failure();
6884
6885 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6886 transposeOp, transposeOp.getResultVectorType(), splat);
6887 return success();
6888 }
6889};
6890
6891/// Folds transpose(create_mask) into a new transposed create_mask.
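/// E.g. transposing create_mask %a, %b : vector<4x8xi1> with permutation
/// [1, 0] yields create_mask %b, %a : vector<8x4xi1>.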
6892class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6893public:
6894 using Base::Base;
6895
6896 LogicalResult matchAndRewrite(TransposeOp transpOp,
6897 PatternRewriter &rewriter) const override {
6898 Value transposeSrc = transpOp.getVector();
6899 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6900 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6901 if (!createMaskOp && !constantMaskOp)
6902 return failure();
6903
6904 // Get the transpose permutation and apply it to the vector.create_mask or
6905 // vector.constant_mask operands.
6906 ArrayRef<int64_t> permutation = transpOp.getPermutation();
6907
6908 if (createMaskOp) {
6909 auto maskOperands = createMaskOp.getOperands();
6910 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6911 applyPermutationToVector(newOperands, permutation);
6912
6913 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6914 transpOp, transpOp.getResultVectorType(), newOperands);
6915 return success();
6916 }
6917
6918 // ConstantMaskOp case.
6919 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6920 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6921
6922 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6923 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6924 return success();
6925 }
6926};
6927
6928/// Folds transpose(shape_cast) into a new shape_cast.
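/// E.g. an order-preserving transpose [1, 0] of
/// shape_cast(vector<4xf32> to vector<4x1xf32>) folds to
/// shape_cast(vector<4xf32> to vector<1x4xf32>).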
6929class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6930public:
6931 using Base::Base;
6932
6933 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6934 PatternRewriter &rewriter) const override {
6935 auto shapeCastOp =
6936 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6937 if (!shapeCastOp)
6938 return failure();
6939 if (!isOrderPreserving(transposeOp))
6940 return failure();
6941
6942 VectorType resultType = transposeOp.getType();
6943
6944 // We don't need to check isValidShapeCast at this point, because it is
6945 // guaranteed that merging the transpose into the shape_cast is a valid
6946 // shape_cast, because the transpose just inserts/removes ones.
6947
6948 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6949 shapeCastOp.getSource());
6950 return success();
6951 }
6952};
6953
6954/// Folds transpose(from_elements(...)) into a new from_elements with permuted
6955/// operands matching the transposed shape.
6956///
6957/// Example:
6958///
6959 /// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 : vector<2x3xi32>
6960 /// %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
6962///
6963/// becomes ->
6964///
6965/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
6966/// vector<3x2xi32>
6967///
6968class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
6969public:
6970 using Base::Base;
6971 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6972 PatternRewriter &rewriter) const override {
6973 auto fromElementsOp =
6974 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
6975 if (!fromElementsOp)
6976 return failure();
6977
6978 VectorType srcTy = fromElementsOp.getDest().getType();
6979 VectorType dstTy = transposeOp.getType();
6980
6981 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
6982 int64_t rank = srcTy.getRank();
6983
6984 // Build inverse permutation to map destination indices back to source.
6985 SmallVector<int64_t> inversePerm(rank, 0);
6986 for (int64_t i = 0; i < rank; ++i)
6987 inversePerm[permutation[i]] = i;
6988
6989 ArrayRef<int64_t> srcShape = srcTy.getShape();
6990 ArrayRef<int64_t> dstShape = dstTy.getShape();
6991 SmallVector<int64_t> srcIdx(rank, 0);
6992 SmallVector<int64_t> dstIdx(rank, 0);
6993 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
6994 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
6995
6996 auto elementsOld = fromElementsOp.getElements();
6997 SmallVector<Value> elementsNew;
6998 int64_t dstNumElements = dstTy.getNumElements();
6999 elementsNew.reserve(dstNumElements);
7000
7001 // For each element in destination row-major order, pick the corresponding
7002 // source element.
7003 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7004 // Pick the destination element index.
7005 dstIdx = delinearize(linearIdx, dstStrides);
7006 // Map the destination element index to the source element index.
7007 for (int64_t j = 0; j < rank; ++j)
7008 srcIdx[j] = dstIdx[inversePerm[j]];
7009 // Linearize the source element index.
7010 int64_t srcLin = linearize(srcIdx, srcStrides);
7011 // Add the source element to the new elements.
7012 elementsNew.push_back(elementsOld[srcLin]);
7013 }
7014
7015 rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
7016 elementsNew);
7017 return success();
7018 }
7019};
7020
7021/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
7022/// 'order preserving', where 'order preserving' means the flattened
7023/// inputs and outputs of the transpose have identical (numerical) values.
7024///
7025/// Example:
7026/// ```
7027/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
7028/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
7029/// to vector<8x1xi32>
7030/// ```
7031/// can be rewritten as the equivalent
7032/// ```
7033/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
7034/// ```
7035/// The algorithm works by partitioning dimensions into groups that can be
7036/// locally permuted while preserving order, and checks that the transpose
7037/// only permutes within these groups.
7038///
7039/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
7040/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
7041/// broadcasting from 1x1x4x1x1x7.
7042/// ^^^ ^ ^^^ ^
7043/// groups: 0 1 2 3
7044/// Order preserving permutations for this example are ones that only permute
7045/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
7046class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
7047public:
7048 using Base::Base;
7049 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7050 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7051
7052 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7053 PatternRewriter &rewriter) const override {
7054
7055 vector::BroadcastOp broadcast =
7056 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
7057 if (!broadcast) {
7058 return rewriter.notifyMatchFailure(transpose,
7059 "not preceded by a broadcast");
7060 }
7061
7062 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
7063 VectorType outputType = transpose.getResultVectorType();
7064
7065 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
7066 bool inputIsScalar = !inputType;
7067 if (inputIsScalar) {
7068 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7069 broadcast.getSource());
7070 return success();
7071 }
7072
7073 ArrayRef<int64_t> permutation = transpose.getPermutation();
7074 ArrayRef<int64_t> inputShape = inputType.getShape();
7075 int64_t inputRank = inputType.getRank();
7076 int64_t outputRank = transpose.getType().getRank();
7077 int64_t deltaRank = outputRank - inputRank;
7078
7079 int low = 0;
7080 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7081 bool notOne = inputShape[inputIndex] != 1;
7082 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7083 bool groupEndFound = notOne || prevNotOne;
7084 if (groupEndFound) {
7085 int high = inputIndex + deltaRank;
7086 // Return failure if not all permutation destinations for indices in
7087 // [low, high) are in [low, high), i.e. the permutation is not local to
7088 // the group.
7089 for (int i = low; i < high; ++i) {
7090 if (permutation[i] < low || permutation[i] >= high) {
7091 return rewriter.notifyMatchFailure(
7092 transpose, "permutation not local to group");
7093 }
7094 }
7095 low = high;
7096 }
7097 }
7098
7099 // We don't need to check the final group [low, outputRank) because if it is
7100 // not locally bound, there must be a preceding group that already failed
7101 // the check (impossible to have just 1 non-locally bound group).
7102
7103 // The preceding logic also ensures that at this point, the output of the
7104 // transpose is definitely broadcastable from the input shape, assert so:
7105 assert(vector::isBroadcastableTo(inputType, outputType) ==
7106 vector::BroadcastableToResult::Success &&
7107 "not broadcastable directly to transpose output");
7108
7109 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7110 broadcast.getSource());
7111
7112 return success();
7113 }
7114};
7115
7116} // namespace
7117
7118void vector::TransposeOp::getCanonicalizationPatterns(
7119 RewritePatternSet &results, MLIRContext *context) {
7120 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7121 FoldTransposeSplat, FoldTransposeFromElements,
7122 FoldTransposeBroadcast>(context);
7123}
7124
7125//===----------------------------------------------------------------------===//
7126// ConstantMaskOp
7127//===----------------------------------------------------------------------===//
7128
7129void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7130 VectorType type, ConstantMaskKind kind) {
7131 assert(kind == ConstantMaskKind::AllTrue ||
7132 kind == ConstantMaskKind::AllFalse);
7133 build(builder, result, type,
7134 kind == ConstantMaskKind::AllTrue
7135 ? type.getShape()
7136 : SmallVector<int64_t>(type.getRank(), 0));
7137}
7138
7139LogicalResult ConstantMaskOp::verify() {
7140 auto resultType = llvm::cast<VectorType>(getResult().getType());
7141 // Check the corner case of 0-D vectors first.
7142 if (resultType.getRank() == 0) {
7143 if (getMaskDimSizes().size() != 1)
7144 return emitError("array attr must have length 1 for 0-D vectors");
7145 auto dim = getMaskDimSizes()[0];
7146 if (dim != 0 && dim != 1)
7147 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7148 return success();
7149 }
7150
7151 // Verify that array attr size matches the rank of the vector result.
7152 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7153 return emitOpError(
7154 "must specify array attr of size equal vector result rank");
7155 // Verify that each array attr element is in bounds of corresponding vector
7156 // result dimension size.
7157 auto resultShape = resultType.getShape();
7158 auto resultScalableDims = resultType.getScalableDims();
7159 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7160 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7161 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7162 return emitOpError(
7163 "array attr of size out of bounds of vector result dimension size");
7164 if (resultScalableDims[index] && maskDimSize != 0 &&
7165 maskDimSize != resultShape[index])
7166 return emitOpError(
7167 "only supports 'none set' or 'all set' scalable dimensions");
7168 }
7169 // Verify that if one mask dim size is zero, they all should be zero (because
7170 // the mask region is a conjunction of each mask dimension interval).
7171 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7172 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7173 if (anyZeros && !allZeros)
7174 return emitOpError("expected all mask dim sizes to be zeros, "
7175 "as a result of conjunction with zero mask dim");
7176 return success();
7177}
7178
7179bool ConstantMaskOp::isAllOnesMask() {
7180 auto resultType = getVectorType();
7181 // Check the corner case of 0-D vectors first.
7182 if (resultType.getRank() == 0) {
7183 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7184 return getMaskDimSizes()[0] == 1;
7185 }
7186 for (const auto [resultSize, maskDimSize] :
7187 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7188 if (maskDimSize < resultSize)
7189 return false;
7190 }
7191 return true;
7192}
7193
7194OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7195 ArrayRef<int64_t> bounds = getMaskDimSizes();
7196 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7197
7198 auto createBoolSplat = [&](bool x) {
7199 return SplatElementsAttr::get(getVectorType(),
7200 BoolAttr::get(getContext(), x));
7201 };
7202
7203 // Check the corner case of 0-D vectors first.
7204 if (vectorSizes.empty()) {
7205 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7206 return createBoolSplat(bounds[0] == 1);
7207 }
7208 // Fold vector.constant_mask to splat if possible.
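// E.g. vector.constant_mask [2, 3] : vector<2x3xi1> folds to an all-true
// splat, and [0, 0] folds to an all-false splat.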
7209 if (bounds == vectorSizes)
7210 return createBoolSplat(true);
7211 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7212 return createBoolSplat(false);
7213 return OpFoldResult();
7214}
7215
7216//===----------------------------------------------------------------------===//
7217// CreateMaskOp
7218//===----------------------------------------------------------------------===//
7219
7220void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7221 VectorType type,
7222 ArrayRef<OpFoldResult> mixedOperands) {
7223 SmallVector<Value> operands =
7224 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7225 build(builder, result, type, operands);
7226}
7227
7228LogicalResult CreateMaskOp::verify() {
7229 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7230 // Verify that an operand was specified for each result vector dimension.
7231 if (vectorType.getRank() == 0) {
7232 if (getNumOperands() != 1)
7233 return emitOpError(
7234 "must specify exactly one operand for 0-D create_mask");
7235 } else if (getNumOperands() !=
7236 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7237 return emitOpError(
7238 "must specify an operand for each result vector dimension");
7239 }
7240 return success();
7241}
7242
7243namespace {
7244
7245/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7246///
7247/// Ex 1:
7248/// %c2 = arith.constant 2 : index
7249/// %c3 = arith.constant 3 : index
7250/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7251/// Becomes:
7252/// vector.constant_mask [3, 2] : vector<4x3xi1>
7253///
7254/// Ex 2:
7255/// %c_neg_1 = arith.constant -1 : index
7256/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7257/// becomes:
7258/// vector.constant_mask [0] : vector<[8]xi1>
7259///
7260/// Ex 3:
7261/// %c8 = arith.constant 8 : index
7262/// %c16 = arith.constant 16 : index
7263/// %0 = vector.vscale
7264/// %1 = arith.muli %0, %c16 : index
7265/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7266/// becomes:
7267/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7268class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7269public:
7270 using Base::Base;
7271
7272 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7273 PatternRewriter &rewriter) const override {
7274 VectorType maskType = createMaskOp.getVectorType();
7275 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7276 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7277
7278 // Special case: Rank zero shape.
7279 constexpr std::array<int64_t, 1> rankZeroShape{1};
7280 constexpr std::array<bool, 1> rankZeroScalableDims{false};
7281 if (maskType.getRank() == 0) {
7282 maskTypeDimSizes = rankZeroShape;
7283 maskTypeDimScalableFlags = rankZeroScalableDims;
7284 }
7285
7286 // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7287 // collect the `constantDims` (for the ConstantMaskOp).
7288 SmallVector<int64_t, 4> constantDims;
7289 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7290 if (auto intSize = getConstantIntValue(dimSize)) {
7291 // Constant value.
7292 // If the mask dim is non-scalable this can be any value.
7293 // If the mask dim is scalable only zero (all-false) is supported.
7294 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7295 return failure();
7296 constantDims.push_back(*intSize);
7297 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7298 // Constant vscale multiple (e.g. 4 x vscale).
7299 // Must be all-true to fold to a ConstantMask.
7300 if (vscaleMultiplier < maskTypeDimSizes[i])
7301 return failure();
7302 constantDims.push_back(*vscaleMultiplier);
7303 } else {
7304 return failure();
7305 }
7306 }
7307
7308 // Clamp values to constant_mask bounds.
7309 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7310 value = std::clamp<int64_t>(value, 0, maskDimSize);
7311
7312 // If one of dim sizes is zero, set all dims to zero.
7313 if (llvm::is_contained(constantDims, 0))
7314 constantDims.assign(constantDims.size(), 0);
7315
7316 // Replace 'createMaskOp' with ConstantMaskOp.
7317 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7318 constantDims);
7319 return success();
7320 }
7321};
7322
7323} // namespace
7324
7325void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7326 MLIRContext *context) {
7327 results.add<CreateMaskFolder>(context);
7328}
7329
7330//===----------------------------------------------------------------------===//
7331// MaskOp
7332//===----------------------------------------------------------------------===//
7333
7334void MaskOp::build(
7335 OpBuilder &builder, OperationState &result, Value mask,
7336 Operation *maskableOp,
7337 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7338 assert(maskRegionBuilder &&
7339 "builder callback for 'maskRegion' must be present");
7340
7341 result.addOperands(mask);
7342 OpBuilder::InsertionGuard guard(builder);
7343 Region *maskRegion = result.addRegion();
7344 builder.createBlock(maskRegion);
7345 maskRegionBuilder(builder, maskableOp);
7346}
7347
7348void MaskOp::build(
7349 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7350 Value mask, Operation *maskableOp,
7351 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7352 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
7353 maskRegionBuilder);
7354}
7355
7356void MaskOp::build(
7357 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7358 Value mask, Value passthru, Operation *maskableOp,
7359 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7360 build(builder, result, mask, maskableOp, maskRegionBuilder);
7361 if (passthru)
7362 result.addOperands(passthru);
7363 result.addTypes(resultTypes);
7364}
7365
7366ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7367 // Create the op region.
7368 result.regions.reserve(1);
7369 Region &maskRegion = *result.addRegion();
7370
7371 auto &builder = parser.getBuilder();
7372
7373 // Parse all the operands.
7374 OpAsmParser::UnresolvedOperand mask;
7375 if (parser.parseOperand(mask))
7376 return failure();
7377
7378 // Optional passthru operand.
7379 OpAsmParser::UnresolvedOperand passthru;
7380 ParseResult parsePassthru = parser.parseOptionalComma();
7381 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7382 return failure();
7383
7384 // Parse op region.
7385 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7386 return failure();
7387
7388 MaskOp::ensureTerminator(maskRegion, builder, result.location);
7389
7390 // Parse the optional attribute list.
7391 if (parser.parseOptionalAttrDict(result.attributes))
7392 return failure();
7393
7394 // Parse all the types.
7395 Type maskType;
7396 if (parser.parseColonType(maskType))
7397 return failure();
7398
7399 SmallVector<Type> resultTypes;
7400 if (parser.parseOptionalArrowTypeList(resultTypes))
7401 return failure();
7402 result.types.append(resultTypes);
7403
7404 // Resolve operands.
7405 if (parser.resolveOperand(mask, maskType, result.operands))
7406 return failure();
7407
7408 if (parsePassthru.succeeded()) {
7409 if (resultTypes.empty())
7410 return parser.emitError(
7411 parser.getNameLoc(),
7412 "expects a result if passthru operand is provided");
7413
7414 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7415 return failure();
7416 }
7417
7418 return success();
7419}
7420
7421void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7422 p << " " << getMask();
7423 if (getPassthru())
7424 p << ", " << getPassthru();
7425
7426 // Print single masked operation and skip terminator.
7427 p << " { ";
7428 Block *singleBlock = &getMaskRegion().getBlocks().front();
7429 if (singleBlock && !singleBlock->getOperations().empty())
7430 p.printCustomOrGenericOp(&singleBlock->front());
7431 p << " }";
7432
7433 p.printOptionalAttrDict(getOperation()->getAttrs());
7434
7435 p << " : " << getMask().getType();
7436 if (getNumResults() > 0)
7437 p << " -> " << getResultTypes();
7438}
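// For illustration (a sketch; %m, %pass, %src, %c0, %pad and %v are assumed
// SSA values), the custom form printed above looks like:
//
//   %0 = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
//          : vector<8xi1> -> f32
//
// and, with a passthru operand:
//
//   %1 = vector.mask %m, %pass { vector.transfer_read %src[%c0], %pad
//          : memref<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>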
7439
7440void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7441 // 1. For an empty `vector.mask`, create a default terminator.
7442 if (region.empty() || region.front().empty()) {
7443 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7444 MaskOp>::ensureTerminator(region, builder, loc);
7445 return;
7446 }
7447
7448 // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7449 Block &block = region.front();
7450 if (isa<vector::YieldOp>(block.back()))
7451 return;
7452
7453 // 3. For a non-empty `vector.mask` without an explicit terminator:
7454
7455 // Create default terminator if the number of masked operations is not
7456 // one. This case will trigger a verification failure.
7457 if (block.getOperations().size() != 1) {
7458 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7459 MaskOp>::ensureTerminator(region, builder, loc);
7460 return;
7461 }
7462
7463 // Create a terminator that yields the results from the masked operation.
7464 OpBuilder opBuilder(builder.getContext());
7465 Operation *maskedOp = &block.front();
7466 opBuilder.setInsertionPointToEnd(&block);
7467 vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7468}
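// For illustration (a sketch; %m, %v, %mem and %i are assumed SSA values),
// parsing a mask region without an explicit terminator, e.g.
//
//   vector.mask %m { vector.transfer_write %v, %mem[%i]
//     : vector<8xf32>, memref<?xf32> } : vector<8xi1>
//
// reaches case 3 above: since the masked transfer_write has no results, an
// empty `vector.yield` is appended; for a masked operation with results, the
// created `vector.yield` forwards them.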
7469
7470LogicalResult MaskOp::verify() {
7471 // Structural checks.
7472 Block &block = getMaskRegion().getBlocks().front();
7473 if (block.getOperations().empty())
7474 return emitOpError("expects a terminator within the mask region");
7475
7476 unsigned numMaskRegionOps = block.getOperations().size();
7477 if (numMaskRegionOps > 2)
7478 return emitOpError("expects only one operation to mask");
7479
7480 // Terminator checks.
7481 auto terminator = dyn_cast<vector::YieldOp>(block.back());
7482 if (!terminator)
7483 return emitOpError("expects a terminator within the mask region");
7484
7485 if (terminator->getNumOperands() != getNumResults())
7486 return emitOpError(
7487 "expects number of results to match mask region yielded values");
7488
7489 // Empty vector.mask. Nothing else to check.
7490 if (numMaskRegionOps == 1)
7491 return success();
7492
7493 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7494 if (!maskableOp)
7495 return emitOpError("expects a MaskableOpInterface within the mask region");
7496
7497 // Result checks.
7498 if (maskableOp->getNumResults() != getNumResults())
7499 return emitOpError("expects number of results to match maskable operation "
7500 "number of results");
7501
7502 if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7503 return emitOpError("expects all the results from the MaskableOpInterface "
7504 "to match all the values returned by the terminator");
7505
7506 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7507 return emitOpError(
7508 "expects result type to match maskable operation result type");
7509
7510 if (llvm::count_if(maskableOp->getResultTypes(),
7511 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7512 return emitOpError("multiple vector results not supported");
7513
7514 // Mask checks.
7515 Type expectedMaskType = maskableOp.getExpectedMaskType();
7516 if (getMask().getType() != expectedMaskType)
7517 return emitOpError("expects a ")
7518 << expectedMaskType << " mask for the maskable operation";
7519
7520 // Passthru checks.
7521 Value passthru = getPassthru();
7522 if (passthru) {
7523 if (!maskableOp.supportsPassthru())
7524 return emitOpError(
7525 "doesn't expect a passthru argument for this maskable operation");
7526
7527 if (maskableOp->getNumResults() != 1)
7528 return emitOpError("expects result when passthru argument is provided");
7529
7530 if (passthru.getType() != maskableOp->getResultTypes()[0])
7531 return emitOpError("expects passthru type to match result type");
7532 }
7533
7534 return success();
7535}
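// As an illustration of the mask-type check above (a sketch; %m and %v are
// assumed SSA values): masking a reduction over vector<8xf32> expects a
// vector<8xi1> mask, so
//
//   %0 = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
//          : vector<4xi1> -> f32
//
// is rejected with "expects a 'vector<8xi1>' mask for the maskable operation".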
7536
7537/// Folds empty `vector.mask` with no passthru operand and with or without
7538/// return values. For example:
7539///
7540/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7541/// vector<8xi1> -> vector<8xf32>
7542/// %1 = user_op %0 : vector<8xf32>
7543///
7544/// becomes:
7545///
7546/// %0 = user_op %a : vector<8xf32>
7547///
7548/// An empty `vector.mask` with a passthru operand is handled by the
7549/// canonicalizer instead, as that requires creating new operations.
7550
7551static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7552 SmallVectorImpl<OpFoldResult> &results) {
7553 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7554 return failure();
7555
7556 Block *block = maskOp.getMaskBlock();
7557 auto terminator = cast<vector::YieldOp>(block->front());
7558 if (terminator.getNumOperands() == 0) {
7559 // `vector.mask` has no results, just remove the `vector.mask`.
7560 return success();
7561 }
7562
7563 // `vector.mask` has results, propagate the results.
7564 llvm::append_range(results, terminator.getOperands());
7565 return success();
7566}
7567
7568LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7569 SmallVectorImpl<OpFoldResult> &results) {
7570 if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7571 return success();
7572
7573 MaskFormat maskFormat = getMaskFormat(getMask());
7574 if (maskFormat != MaskFormat::AllTrue)
7575 return failure();
7576
7577 // Move maskable operation outside of the `vector.mask` region.
7578 Operation *maskableOp = getMaskableOp();
7579 maskableOp->dropAllUses();
7580 maskableOp->moveBefore(getOperation());
7581
7582 llvm::append_range(results, maskableOp->getResults());
7583 return success();
7584}
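// For illustration (a sketch; %v is an assumed SSA value), the all-true fold
// above rewrites
//
//   %m = vector.constant_mask [8] : vector<8xi1>
//   %0 = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
//          : vector<8xi1> -> f32
//
// into the unmasked
//
//   %0 = vector.reduction <add>, %v : vector<8xf32> into f32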
7585
7586/// Canonicalize empty `vector.mask` operations that can't be handled in
7587/// `MaskOp::fold`, as they require creating new operations.
7588///
7589/// Example 1: Empty `vector.mask` with passthru operand.
7590///
7591/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7592/// vector<8xi1> -> vector<8xf32>
7593///
7594/// becomes:
7595///
7596/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7597///
7598class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7599 using Base::Base;
7600
7601 LogicalResult matchAndRewrite(MaskOp maskOp,
7602 PatternRewriter &rewriter) const override {
7603 if (!maskOp.isEmpty())
7604 return failure();
7605
7606 if (!maskOp.hasPassthru())
7607 return failure();
7608
7609 Block *block = maskOp.getMaskBlock();
7610 auto terminator = cast<vector::YieldOp>(block->front());
7611 assert(terminator.getNumOperands() == 1 &&
7612 "expected one result when passthru is provided");
7613
7614 rewriter.replaceOpWithNewOp<arith::SelectOp>(
7615 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7616 terminator.getOperand(0), maskOp.getPassthru());
7617
7618 return success();
7619 }
7620};
7621
7622void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7623 MLIRContext *context) {
7624 results.add<CanonializeEmptyMaskOp>(context);
7625}
7626
7627// MaskingOpInterface definitions.
7628
7629/// Returns the operation masked by this 'vector.mask'.
7630Operation *MaskOp::getMaskableOp() {
7631 Block *block = getMaskBlock();
7632 if (block->getOperations().size() < 2)
7633 return nullptr;
7634
7635 return &block->front();
7636}
7637
7638/// Returns true if 'vector.mask' has a passthru value.
7639bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7640
7641//===----------------------------------------------------------------------===//
7642// ScanOp
7643//===----------------------------------------------------------------------===//
7644
7645LogicalResult ScanOp::verify() {
7646 VectorType srcType = getSourceType();
7647 VectorType initialType = getInitialValueType();
7648 // Check reduction dimension < rank.
7649 int64_t srcRank = srcType.getRank();
7650 int64_t reductionDim = getReductionDim();
7651 if (reductionDim >= srcRank)
7652 return emitOpError("reduction dimension ")
7653 << reductionDim << " has to be less than " << srcRank;
7654
7655 // Check that rank(initial_value) = rank(src) - 1.
7656 int64_t initialValueRank = initialType.getRank();
7657 if (initialValueRank != srcRank - 1)
7658 return emitOpError("initial value rank ")
7659 << initialValueRank << " has to be equal to " << srcRank - 1;
7660
7661 // Check shapes of initial value and src.
7662 ArrayRef<int64_t> srcShape = srcType.getShape();
7663 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7664 SmallVector<int64_t> expectedShape;
7665 for (int i = 0; i < srcRank; i++) {
7666 if (i != reductionDim)
7667 expectedShape.push_back(srcShape[i]);
7668 }
7669 if (!llvm::equal(initialValueShapes, expectedShape)) {
7670 return emitOpError("incompatible input/initial value shapes");
7671 }
7672
7673 // Verify supported reduction kind.
7674 Type eltType = getDestType().getElementType();
7675 if (!isSupportedCombiningKind(getKind(), eltType))
7676 return emitOpError("unsupported reduction type ")
7677 << eltType << " for kind '" << stringifyCombiningKind(getKind())
7678 << "'";
7679
7680 return success();
7681}
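// As an illustration of the shape rules checked above (a sketch; %src and
// %init are assumed SSA values), a scan along reduction_dim = 1 drops exactly
// that dimension from the initial value:
//
//   %dest, %acc = vector.scan <add>, %src, %init
//     {inclusive = true, reduction_dim = 1 : i64}
//     : vector<4x8x16xf32>, vector<4x16xf32>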
7682
7683void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
7684 RewritePatternSet &patterns, PatternBenefit benefit) {
7685 patterns
7686 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7687 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7688 StridedSliceConstantMaskFolder, TransposeFolder>(
7689 patterns.getContext(), benefit);
7690}
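// A minimal usage sketch for the entry point above (assumes a pass context and
// the greedy driver from mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorToVectorCanonicalizationPatterns(patterns);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();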
7691
7692Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7693 CombiningKind kind, Value v1, Value acc,
7694 arith::FastMathFlagsAttr fastmath,
7695 Value mask) {
7696 Type t1 = getElementTypeOrSelf(v1.getType());
7697 Type tAcc = getElementTypeOrSelf(acc.getType());
7698 Value result;
7699
7700 switch (kind) {
7701 case CombiningKind::ADD:
7702 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7703 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7704 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7705 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7706 else
7707 llvm_unreachable("invalid value types for ADD reduction");
7708 break;
7709 case CombiningKind::AND:
7710 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7711 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7712 break;
7713 case CombiningKind::MAXNUMF:
7714 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7715 "expected float values");
7716 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7717 break;
7718 case CombiningKind::MAXIMUMF:
7719 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7720 "expected float values");
7721 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7722 break;
7723 case CombiningKind::MINNUMF:
7724 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7725 "expected float values");
7726 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7727 break;
7728 case CombiningKind::MINIMUMF:
7729 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7730 "expected float values");
7731 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7732 break;
7733 case CombiningKind::MAXSI:
7734 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7735 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7736 break;
7737 case CombiningKind::MINSI:
7738 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7739 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7740 break;
7741 case CombiningKind::MAXUI:
7742 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7743 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7744 break;
7745 case CombiningKind::MINUI:
7746 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7747 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7748 break;
7749 case CombiningKind::MUL:
7750 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7751 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7752 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7753 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7754 else
7755 llvm_unreachable("invalid value types for MUL reduction");
7756 break;
7757 case CombiningKind::OR:
7758 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7759 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7760 break;
7761 case CombiningKind::XOR:
7762 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7763 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7764 break;
7765 };
7766
7767 assert(result && "unknown CombiningKind");
7768 return selectPassthru(b, mask, result, acc);
7769}
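// For illustration (a sketch; %v, %acc and %mask are assumed vector<4xf32> /
// vector<4xi1> values), CombiningKind::ADD with a mask emits roughly:
//
//   %sum = arith.addf %v, %acc : vector<4xf32>
//   %res = arith.select %mask, %sum, %acc : vector<4xi1>, vector<4xf32>
//
// i.e. masked-off lanes keep the accumulator value.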
7770
7771//===----------------------------------------------------------------------===//
7772// StepOp
7773//===----------------------------------------------------------------------===//
7774
7775void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7776 SetIntRangeFn setResultRanges) {
7777 auto resultType = cast<VectorType>(getType());
7778 if (resultType.isScalable()) {
7779 return;
7780 }
7781 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7782 APInt zero(bitwidth, 0);
7783 APInt high(bitwidth, resultType.getDimSize(0) - 1);
7784 ConstantIntRanges result = {zero, high, zero, high};
7785 setResultRanges(getResult(), result);
7786}
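// For illustration (a sketch): for
//
//   %s = vector.step : vector<8xindex>
//
// the range recorded above for every lane is [0, 7] for both the signed and
// unsigned bounds; scalable result types are skipped because the number of
// lanes is not known statically.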
7787
7788namespace {
7789
7790/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
7791/// constant large enough such that the result is the same at all indices.
7792///
7793/// For example, rewrite the 'greater than' comparison below,
7794///
7795/// ```mlir
7796/// %cst = arith.constant dense<7> : vector<3xindex>
7797/// %stp = vector.step : vector<3xindex>
7798/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
7799/// ```
7800///
7801/// as,
7802///
7803/// ```mlir
7804/// %out = arith.constant dense<false> : vector<3xi1>.
7805/// ```
7806///
7807/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
7808/// is false at ALL indices, we fold. If the constant were 1, then
7809/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
7810/// conservatively preferring the 'compact' vector.step representation.
7811///
7812/// Note: this folder only works for the case where the constant (`%cst` above)
7813/// is the second operand of the comparison. The arith.cmpi canonicalizer will
7814/// ensure that constants are always second (on the right).
7815struct StepCompareFolder : public OpRewritePattern<StepOp> {
7816 using Base::Base;
7817
7818 LogicalResult matchAndRewrite(StepOp stepOp,
7819 PatternRewriter &rewriter) const override {
7820 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
7821
7822 for (OpOperand &use : stepOp.getResult().getUses()) {
7823 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
7824 if (!cmpiOp)
7825 continue;
7826
7827 // arith.cmpi canonicalizer makes constants final operands.
7828 const unsigned stepOperandNumber = use.getOperandNumber();
7829 if (stepOperandNumber != 0)
7830 continue;
7831
7832 // Check that operand 1 is a constant.
7833 unsigned constOperandNumber = 1;
7834 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7835 std::optional<int64_t> maybeConstValue =
7836 getConstantIntValue(otherOperand);
7837 if (!maybeConstValue.has_value())
7838 continue;
7839
7840 int64_t constValue = maybeConstValue.value();
7841 arith::CmpIPredicate pred = cmpiOp.getPredicate();
7842
7843 auto maybeSplat = [&]() -> std::optional<bool> {
7844 // Handle ult (unsigned less than) and uge (unsigned greater equal).
7845 if ((pred == arith::CmpIPredicate::ult ||
7846 pred == arith::CmpIPredicate::uge) &&
7847 stepSize <= constValue)
7848 return pred == arith::CmpIPredicate::ult;
7849
7850 // Handle ule and ugt.
7851 if ((pred == arith::CmpIPredicate::ule ||
7852 pred == arith::CmpIPredicate::ugt) &&
7853 stepSize - 1 <= constValue) {
7854 return pred == arith::CmpIPredicate::ule;
7855 }
7856
7857 // Handle eq and ne.
7858 if ((pred == arith::CmpIPredicate::eq ||
7859 pred == arith::CmpIPredicate::ne) &&
7860 stepSize <= constValue)
7861 return pred == arith::CmpIPredicate::ne;
7862
7863 return std::nullopt;
7864 }();
7865
7866 if (!maybeSplat.has_value())
7867 continue;
7868
7869 rewriter.setInsertionPointAfter(cmpiOp);
7870
7871 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7872 if (!type)
7873 continue;
7874
7875 auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
7876 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7877 type, boolAttr);
7878
7879 rewriter.replaceOp(cmpiOp, splat);
7880 return success();
7881 }
7882
7883 return failure();
7884 }
7885};
7886} // namespace
7887
7888void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
7889 MLIRContext *context) {
7890 results.add<StepCompareFolder>(context);
7891}
7892
7893//===----------------------------------------------------------------------===//
7894// Vector Masking Utilities
7895//===----------------------------------------------------------------------===//
7896
7897/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7898/// as the masked operation.
7899void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7900 Operation *maskableOp) {
7901 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
7902 Block *insBlock = builder.getInsertionBlock();
7903 // Create a block and move the op to that block.
7904 insBlock->getOperations().splice(
7905 insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7906 YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
7907}
7908
7909/// Creates a vector.mask operation around a maskable operation. Returns the
7910/// vector.mask operation if the mask provided is valid. Otherwise, returns
7911/// the maskable operation itself.
7912Operation *mlir::vector::maskOperation(OpBuilder &builder,
7913 Operation *maskableOp, Value mask,
7914 Value passthru) {
7915 if (!mask)
7916 return maskableOp;
7917 if (passthru)
7918 return MaskOp::create(builder, maskableOp->getLoc(),
7919 maskableOp->getResultTypes(), mask, passthru,
7920 maskableOp, createMaskOpRegion);
7921 return MaskOp::create(builder, maskableOp->getLoc(),
7922 maskableOp->getResultTypes(), mask, maskableOp,
7923 createMaskOpRegion);
7924}
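// A minimal usage sketch for the helper above (assumes `builder`, a maskable
// `readOp` already inserted into a block, and a vector<8xi1> value `maskVal`):
//
//   Operation *masked = vector::maskOperation(builder, readOp, maskVal);
//
// which produces, roughly,
//
//   %0 = vector.mask %maskVal { vector.transfer_read %src[%i], %pad
//          : memref<?xf32>, vector<8xf32> } : vector<8xi1> -> vector<8xf32>
//
// When `maskVal` is null, the original operation is returned unchanged.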
7925
7926/// Creates a vector select operation that picks values from `newValue` or
7927/// `passthru` for each result vector lane based on `mask`. This utility is used
7928/// to propagate the pass-thru value of vector.mask, or for cases where only
7929/// pass-thru value propagation is needed. VP intrinsics do not support
7930/// pass-thru values, and every masked-out lane is set to poison. LLVM backends
7931/// are usually able to match op + select patterns and fold them into native
7932/// target instructions.
7933Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7934 Value newValue, Value passthru) {
7935 if (!mask)
7936 return newValue;
7937
7938 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
7939 mask, newValue, passthru);
7940}
7941
7942//===----------------------------------------------------------------------===//
7943// TableGen'd op method definitions
7944//===----------------------------------------------------------------------===//
7945
7946#define GET_ATTRDEF_CLASSES
7947#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
7948
7949#define GET_OP_CLASSES
7950#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"