//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements convenience types for working with super-vectorization
// operations, in particular super-vector loads and stores.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"

#include <cassert>
#include <cstdint>
#include <numeric>

#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"

using namespace mlir;
using namespace mlir::vector;

/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,
  AllFalse = 1,
  Unknown = 2,
};

/// Helper method to classify a mask value. Currently, the method
/// looks "under the hood" of a constant value with dense attributes
/// and a constant mask operation (since the client may be called at
/// various stages during progressive lowering).
static MaskFormat getMaskFormat(Value mask) {
  if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
    // Inspect constant dense values. We count up for bits that
    // are set, count down for bits that are cleared, and bail
    // when a mix is detected.
    if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
      int64_t val = 0;
      for (bool b : denseElts.getValues<bool>())
        if (b && val >= 0)
          val++;
        else if (!b && val <= 0)
          val--;
        else
          return MaskFormat::Unknown;
      if (val > 0)
        return MaskFormat::AllTrue;
      if (val < 0)
        return MaskFormat::AllFalse;
    }
  } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
    // Inspect constant mask index. If the index exceeds the
    // dimension size, all bits are set. If the index is zero
    // or less, no bits are set.
    ArrayRef<int64_t> masks = m.getMaskDimSizes();
    auto shape = m.getType().getShape();
    bool allTrue = true;
    bool allFalse = true;
    for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
      if (maskIdx < dimSize)
        allTrue = false;
      if (maskIdx > 0)
        allFalse = false;
    }
    if (allTrue)
      return MaskFormat::AllTrue;
    if (allFalse)
      return MaskFormat::AllFalse;
  } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
    // Finds all-false create_masks. An all-true create_mask requires all
    // dims to be constants, so that'll be folded to a constant_mask, then
    // detected in the constant_mask case.
    auto maskOperands = m.getOperands();
    for (Value operand : maskOperands) {
      if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
        int64_t dimSize =
            llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
        if (dimSize <= 0)
          return MaskFormat::AllFalse;
      }
    }
    return MaskFormat::Unknown;
  }
  return MaskFormat::Unknown;
}
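
// Illustrative examples of the classification above (added for exposition;
// not part of the upstream file):
//   vector.constant_mask [4] : vector<4xi1>    --> MaskFormat::AllTrue
//   vector.constant_mask [0] : vector<4xi1>    --> MaskFormat::AllFalse
//   arith.constant dense<[true, false, true, true]> : vector<4xi1>
//                                              --> MaskFormat::Unknown
// A vector.create_mask whose constant operand is <= 0 is also classified
// all-false; an all-true create_mask is expected to fold to a constant_mask
// first and is then caught by the constant_mask case above.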

/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  vector::YieldOp::create(builder, loc);
}

// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
                                     Type elementType) {
  switch (combiningKind) {
  case CombiningKind::ADD:
  case CombiningKind::MUL:
    return elementType.isIntOrIndexOrFloat();
  case CombiningKind::MINUI:
  case CombiningKind::MINSI:
  case CombiningKind::MAXUI:
  case CombiningKind::MAXSI:
  case CombiningKind::AND:
  case CombiningKind::OR:
  case CombiningKind::XOR:
    return elementType.isIntOrIndex();
  case CombiningKind::MINNUMF:
  case CombiningKind::MAXNUMF:
  case CombiningKind::MINIMUMF:
  case CombiningKind::MAXIMUMF:
    return llvm::isa<FloatType>(elementType);
  }
  return false;
}
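
// For illustration (added; not in the upstream file): integer element types
// pair with the integer kinds and float element types with the float min/max
// kinds, e.g.
//   isSupportedCombiningKind(CombiningKind::XOR, i32)      --> true
//   isSupportedCombiningKind(CombiningKind::MINNUMF, i32)  --> false
//   isSupportedCombiningKind(CombiningKind::ADD, f32)      --> true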

/// Returns the effective rank of the vector to read/write for Xfer Ops
///
/// When the element type of the shaped type is _a scalar_, this will simply
/// return the rank of the vector (the result for xfer_read or the value to
/// store for xfer_write).
///
/// When the element type of the base shaped type is _a vector_, returns the
/// difference between the original vector type and the element type of the
/// shaped type.
///
/// EXAMPLE 1 (element type is _a scalar_):
/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
/// - shapedType.getElementType() = f32 (rank 0)
/// - vectorType.getRank() = 2
/// - Result = 2 - 0 = 2
///
/// EXAMPLE 2 (element type is _a vector_):
/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
/// - shapedType.getElementType() = vector<20xf32> (rank 1)
/// - vectorType.getRank() = 1
/// - Result = 1 - 1 = 0
///
/// This is used to determine the number of minor dimensions for identity maps
/// in vector transfer Ops.
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
                                                VectorType vectorType) {
  unsigned elementVectorRank = 0;
  VectorType elementVectorType =
      llvm::dyn_cast<VectorType>(shapedType.getElementType());
  if (elementVectorType)
    elementVectorRank += elementVectorType.getRank();
  return vectorType.getRank() - elementVectorRank;
}

AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
                                                    VectorType vectorType) {
  // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
  // TODO: replace once we have 0-d vectors.
  if (shapedType.getRank() == 0 &&
      vectorType.getShape() == ArrayRef<int64_t>{1})
    return AffineMap::get(
        /*numDims=*/0, /*numSymbols=*/0,
        getAffineConstantExpr(0, shapedType.getContext()));
  return AffineMap::getMinorIdentityMap(
      shapedType.getRank(),
      getEffectiveVectorRankForXferOp(shapedType, vectorType),
      shapedType.getContext());
}
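
// Worked example (added for exposition): for shapedType = memref<10x20x30xf32>
// and vectorType = vector<4x8xf32>, the effective vector rank is 2, so the
// returned permutation map is the minor identity
//   (d0, d1, d2) -> (d1, d2)
// i.e. the transfer reads/writes along the two innermost memref dimensions.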

/// Check if `write` is of a constant splat and the masked `read` is padded
/// with the same splat value -- meaning it could be the same value as the
/// initial constant splat.
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
                                                 vector::TransferReadOp read) {
  auto readMask = read.getMask();
  auto writeMask = write.getMask();
  // Check if the masks are consistent. The splat value could be the same if
  // the read is masked (and padded with the splat value), and the write is
  // unmasked or has the same mask. Note this does not allow the case where the
  // write is masked and the read is unmasked, as then the read could be of
  // more elements than the write (which may not be the same value).
  bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
  if (!couldBeSameSplat)
    return false;
  // Check for constant splat (as the source of the write).
  DenseElementsAttr splatAttr;
  if (!matchPattern(write.getVector(),
                    m_Constant<DenseElementsAttr>(&splatAttr)) ||
      !splatAttr.isSplat()) {
    return false;
  }
  // The padding of the read and the constant splat value must be the same.
  Attribute padAttr;
  if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
    return false;
  return padAttr == splatAttr.getSplatValue<Attribute>();
}
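
// Example (added; simplified, hypothetical SSA names): when a masked read is
// padded with the same value that a constant splat write stored, every lane of
// the read is either the written splat or the identical padding value:
//   %cst = arith.constant dense<0.0> : vector<4xf32>
//   %w   = vector.transfer_write %cst, %t[%c0], %mask
//            : vector<4xf32>, tensor<8xf32>
//   %r   = vector.transfer_read %w[%c0], %f0, %mask
//            : tensor<8xf32>, vector<4xf32>
// With %f0 == 0.0, the check above succeeds for this write/read pair.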

bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
                                     vector::TransferReadOp read) {
  return !defWrite.hasOutOfBoundsDim() &&
         defWrite.getIndices() == read.getIndices() &&
         defWrite.getVectorType() == read.getVectorType() &&
         defWrite.getPermutationMap() == read.getPermutationMap() &&
         ((!defWrite.getMask() && !read.getMask()) ||
          isSplatWriteConsistentWithMaskedRead(defWrite, read));
}

bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
                                     vector::TransferWriteOp priorWrite) {
  return priorWrite.getIndices() == write.getIndices() &&
         priorWrite.getMask() == write.getMask() &&
         priorWrite.getVectorType() == write.getVectorType() &&
         priorWrite.getPermutationMap() == write.getPermutationMap();
}
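
// Example of the read-after-write forwarding that checkSameValueRAW sanctions
// (added; simplified): when the write is fully in-bounds and the indices,
// vector types, and permutation maps match, the read returns exactly the
// written vector.
//   %w = vector.transfer_write %v, %t[%c0, %c0] {in_bounds = [true, true]}
//          : vector<4x8xf32>, tensor<4x8xf32>
//   %r = vector.transfer_read %w[%c0, %c0], %pad {in_bounds = [true, true]}
//          : tensor<4x8xf32>, vector<4x8xf32>
//   // %r can be replaced by %v.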

bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
    bool testDynamicValueUsingBounds) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
    Value indexA = transferA.getIndices()[i];
    Value indexB = transferB.getIndices()[i];
    std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
    std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that the indices are different
      // we know we are accessing disjoint slices.
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        if (*cstIndexA != *cstIndexB)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<uint64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && *delta != 0)
          return true;

        FailureOr<bool> testEqual =
            ValueBoundsConstraintSet::areEqual(indexA, indexB);
        if (succeeded(testEqual) && !testEqual.value())
          return true;
      }
    } else {
      // For this dimension, we slice a part of the memref, so we need to make
      // sure the intervals accessed don't overlap.
      int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
      if (cstIndexA.has_value() && cstIndexB.has_value()) {
        int64_t distance = std::abs(*cstIndexA - *cstIndexB);
        if (distance >= vectorDim)
          return true;
        continue;
      }
      if (testDynamicValueUsingBounds) {
        // First try to see if we can fully compose and simplify the affine
        // expression as a fast track.
        FailureOr<int64_t> delta =
            affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
        if (succeeded(delta) && std::abs(*delta) >= vectorDim)
          return true;

        FailureOr<int64_t> computeDelta =
            ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
        if (succeeded(computeDelta)) {
          if (std::abs(computeDelta.value()) >= vectorDim)
            return true;
        }
      }
    }
  }
  return false;
}

bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                         VectorTransferOpInterface transferB,
                                         bool testDynamicValueUsingBounds) {
  if (transferA.getBase() != transferB.getBase())
    return false;
  return isDisjointTransferIndices(transferA, transferB,
                                   testDynamicValueUsingBounds);
}
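
// Example (added; simplified): the two reads below access [0, 4) and [4, 8)
// of the same memref dimension, so with the constant indices %c0 and %c4 the
// transfers are provably disjoint.
//   %r0 = vector.transfer_read %m[%c0], %pad : memref<8xf32>, vector<4xf32>
//   %r1 = vector.transfer_read %m[%c4], %pad : memref<8xf32>, vector<4xf32>
// Here the index distance (4) is >= the vector dimension (4), which is the
// condition checked above.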

// Helper to iterate over n-D vector slice elements. Calculate the next
// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
// Modifies the `position` in place. Returns a failure when `position` becomes
// the end position.
static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
                                      ArrayRef<int64_t> shape,
                                      ArrayRef<int64_t> offsets) {
  for (auto [posInDim, dimSize, offsetInDim] :
       llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
    ++posInDim;
    if (posInDim < dimSize + offsetInDim)
      return success();

    // Carry the overflow to the next loop iteration.
    posInDim = offsetInDim;
  }

  return failure();
}
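
// Worked example (added for exposition): with shape = [2, 3] and
// offsets = [0, 0], successive calls starting from position = [0, 0] visit
//   [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2]
// and the call after [1, 2] returns failure() to signal the end position.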

/// Returns the integer numbers in `values`. `values` are expected to be
/// constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
  SmallVector<int64_t> ints;
  llvm::transform(values, std::back_inserter(ints), [](Value value) {
    auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
    assert(constOp && "Unexpected non-constant index");
    return constOp.value();
  });
  return ints;
}

/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
/// be constant operations.
SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
  SmallVector<int64_t> ints;
  llvm::transform(
      foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
        assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
        return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
      });
  return ints;
}

/// Convert `foldResults` into Values. Integer attributes are converted to
/// constant op.
SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> foldResults) {
  SmallVector<Value> values;
  llvm::transform(foldResults, std::back_inserter(values),
                  [&](OpFoldResult foldResult) {
                    if (auto attr = dyn_cast<Attribute>(foldResult))
                      return arith::ConstantIndexOp::create(
                                 builder, loc, cast<IntegerAttr>(attr).getInt())
                          .getResult();

                    return cast<Value>(foldResult);
                  });
  return values;
}

std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
  if (value.getDefiningOp<vector::VectorScaleOp>())
    return 1;
  auto mul = value.getDefiningOp<arith::MulIOp>();
  if (!mul)
    return {};
  auto lhs = mul.getLhs();
  auto rhs = mul.getRhs();
  if (lhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(rhs);
  if (rhs.getDefiningOp<vector::VectorScaleOp>())
    return getConstantIntValue(lhs);
  return {};
}
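
// Example (added for exposition): for
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index
// getConstantVscaleMultiplier(%n) returns 4, and for %vs itself it returns 1;
// any value that is not vscale or vscale times a constant yields std::nullopt.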

/// Converts numeric attributes to the expected type. Supports
/// integer-to-integer and float-to-integer conversions. Returns the original
/// attribute if no conversion is needed or supported.
static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
  // Integer-to-integer conversion
  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    if (auto intType = dyn_cast<IntegerType>(expectedType)) {
      if (intAttr.getType() != expectedType)
        return IntegerAttr::get(expectedType, intAttr.getInt());
    }
    return attr;
  }

  // Float-to-integer bitcast (preserves bit representation)
  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
    auto intType = dyn_cast<IntegerType>(expectedType);
    if (!intType)
      return attr;

    APFloat floatVal = floatAttr.getValue();
    APInt intVal = floatVal.bitcastToAPInt();
    return IntegerAttr::get(expectedType, intVal);
  }

  return attr;
}
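
// Example (added for exposition): converting a 1.0 : f32 attribute to i32
// bitcasts the IEEE-754 encoding rather than the numeric value, i.e. it
// produces 0x3F800000 (1065353216 : i32); an integer attribute requested as a
// different integer type is instead rebuilt from its integer value.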

//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//

namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir

//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with vector dialect
/// operations.
struct VectorInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// All vector dialect ops can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace

void VectorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
      >();

  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();

  addInterfaces<VectorInlinerInterface>();

  declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                            TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                            YieldOp>();
  declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
                            TransferWriteOp>();
  declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
  declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  if (isa<ub::PoisonAttrInterface>(value))
    return value.getDialect().materializeConstant(builder, value, type, loc);

  return arith::ConstantOp::materialize(builder, value, type, loc);
}

IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}

ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}

//===----------------------------------------------------------------------===//
// MultiDimReductionOp
//===----------------------------------------------------------------------===//

void vector::MultiDimReductionOp::build(OpBuilder &builder,
                                        OperationState &result, Value source,
                                        Value acc, ArrayRef<bool> reductionMask,
                                        CombiningKind kind) {
  SmallVector<int64_t> reductionDims;
  for (const auto &en : llvm::enumerate(reductionMask))
    if (en.value())
      reductionDims.push_back(en.index());
  build(builder, result, kind, source, acc, reductionDims);
}

OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}

std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

LogicalResult MultiDimReductionOp::verify() {
  SmallVector<int64_t> targetShape;
  SmallVector<bool> scalableDims;
  Type inferredReturnType;
  auto sourceScalableDims = getSourceVectorType().getScalableDims();
  for (auto [dimIdx, dimSize] :
       llvm::enumerate(getSourceVectorType().getShape()))
    if (!llvm::any_of(getReductionDims(),
                      [dimIdx = dimIdx](int64_t reductionDimIdx) {
                        return reductionDimIdx == static_cast<int64_t>(dimIdx);
                      })) {
      targetShape.push_back(dimSize);
      scalableDims.push_back(sourceScalableDims[dimIdx]);
    }
  // TODO: update to also allow 0-d vectors when available.
  if (targetShape.empty())
    inferredReturnType = getSourceVectorType().getElementType();
  else
    inferredReturnType = VectorType::get(
        targetShape, getSourceVectorType().getElementType(), scalableDims);
  if (getType() != inferredReturnType)
    return emitOpError() << "destination type " << getType()
                         << " is incompatible with source type "
                         << getSourceVectorType();

  return success();
}
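
// Shape inference example for the check above (added): reducing dimension 1
// of vector<4x8x16xf32> must produce vector<4x16xf32>, while reducing all
// dimensions of the same source must produce the scalar element type f32.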

/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

namespace {
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
// same. If not all dimensions which are reduced are of unit dimension, this
// transformation does nothing. This is just a generalization of
// ElideSingleElementReduction for ReduceOp.
struct ElideUnitDimsInMultiDimReduction
    : public OpRewritePattern<MultiDimReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
    for (const auto &dim : enumerate(shape)) {
      if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
        return failure();
    }

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    Operation *rootOp;
    Value mask;
    if (reductionOp.isMasked()) {
      rewriter.setInsertionPoint(reductionOp.getMaskingOp());
      rootOp = reductionOp.getMaskingOp();
      mask = reductionOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    Location loc = reductionOp.getLoc();
    Value acc = reductionOp.getAcc();
    Value cast;
    if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
      if (mask) {
        VectorType newMaskType =
            VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
                            dstVecType.getScalableDims());
        mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
      }
      cast = vector::ShapeCastOp::create(
          rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
    } else {
      // This means we are reducing all the dimensions, and all reduction
      // dimensions are of size 1. So a simple extraction would do.
      if (mask)
        mask = vector::ExtractOp::create(rewriter, loc, mask);
      cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
    }

    Value result =
        vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
                                   cast, /*fastmath=*/nullptr, mask);
    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace
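
// Example of the elision above (added; simplified): reducing a unit dimension
// is a pure reshape, so
//   %r = vector.multi_reduction <add>, %v, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
// becomes a vector.shape_cast of %v to vector<4xf32> combined with %acc via
// arith.addf.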

void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}

//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
}

void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc,
                                arith::FastMathFlags fastMathFlags) {
  build(builder, result,
        llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
        acc, fastMathFlags);
}

LogicalResult ReductionOp::verify() {
  // Verify for 0-D and 1-D vector.
  int64_t rank = getSourceVectorType().getRank();
  if (rank > 1)
    return emitOpError("unsupported reduction rank: ") << rank;

  // Verify supported reduction kind.
  Type eltType = getDest().getType();
  if (!isSupportedCombiningKind(getKind(), eltType))
    return emitOpError("unsupported reduction type '")
           << eltType << "' for kind '" << stringifyCombiningKind(getKind())
           << "'";

  return success();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
  auto vecType = getSourceVectorType();
  return VectorType::get(vecType.getShape(),
                         IntegerType::get(vecType.getContext(), /*width=*/1),
                         vecType.getScalableDims());
}

Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minimumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINIMUMF, vector);
  case arith::AtomicRMWKind::mins:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maximumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXIMUMF, vector);
  case arith::AtomicRMWKind::maxs:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::OR, vector);
  case arith::AtomicRMWKind::minnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MINNUMF, vector);
  case arith::AtomicRMWKind::maxnumf:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::MAXNUMF, vector);
  case arith::AtomicRMWKind::xori:
    return vector::ReductionOp::create(builder, vector.getLoc(),
                                       CombiningKind::XOR, vector);
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}

std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}

namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp =
        cast<vector::MaskableOpInterface>(reductionOp.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = reductionOp;
    }

    auto vectorType = reductionOp.getSourceVectorType();
    if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
      return failure();

    Location loc = reductionOp.getLoc();
    if (mask)
      mask = ExtractOp::create(rewriter, loc, mask);
    Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());

    if (Value acc = reductionOp.getAcc())
      result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                          result, acc,
                                          reductionOp.getFastmathAttr(), mask);

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};
} // namespace
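
// Example of the elision above (added; simplified): a reduction of a
// single-element vector is just that element, so
//   %r = vector.reduction <add>, %v : vector<1xf32> into f32
// becomes a vector.extract of %v (combined with the accumulator via
// arith.addf when one is present).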

void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ElideSingleElementReduction>(context);
}

//===----------------------------------------------------------------------===//
// ContractionOp
//===----------------------------------------------------------------------===//

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(
      getIndexingMapsAttrName(result.name),
      builder.getAffineMapArrayAttr(
          AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}

void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
}

ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand lhsInfo;
  OpAsmParser::UnresolvedOperand rhsInfo;
  OpAsmParser::UnresolvedOperand accInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
  SmallVector<Type, 2> types;
  Type resultType;
  auto loc = parser.getCurrentLocation();
  DictionaryAttr dictAttr;
  // TODO: Unify linalg op attribute parsing.
  if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
      parser.parseComma() || parser.parseOperand(rhsInfo) ||
      parser.parseComma() || parser.parseOperand(accInfo) ||
      parser.parseTrailingOperandList(masksInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonTypeList(types) ||
      parser.parseKeywordType("into", resultType) ||
      parser.resolveOperand(lhsInfo, types[0], result.operands) ||
      parser.resolveOperand(rhsInfo, types[1], result.operands) ||
      parser.resolveOperand(accInfo, resultType, result.operands) ||
      parser.addTypeToList(resultType, result.types))
    return failure();
  result.attributes.append(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert an array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  if (!result.attributes.get(getKindAttrName(result.name))) {
    result.addAttribute(
        getKindAttrName(result.name),
        CombiningKindAttr::get(result.getContext(),
                               ContractionOp::getDefaultKind()));
  }
  if (masksInfo.empty())
    return success();
  if (masksInfo.size() != 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected zero or exactly 2 vector mask operands");
  auto lhsType = llvm::cast<VectorType>(types[0]);
  auto rhsType = llvm::cast<VectorType>(types[1]);
  auto maskElementType = parser.getBuilder().getI1Type();
  std::array<VectorType, 2> maskTypes = {
      VectorType::Builder(lhsType).setElementType(maskElementType),
      VectorType::Builder(rhsType).setElementType(maskElementType)};
  if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
    return failure();
  return success();
}

void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert_range(attrNames);
  SmallVector<NamedAttribute> attrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();

  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}

static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
                         const std::vector<std::pair<int64_t, int64_t>> &map) {
  for (auto &dimPair : map) {
    if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
        dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
        lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
      return false;
  }
  return true;
}

static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
                                   llvm::make_second_range(batchDimMap));

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = llvm::dyn_cast<VectorType>(resType);
    auto accVectorType = llvm::dyn_cast<VectorType>(accType);
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs(), nullptr);
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symbolCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of(expectedMap.getResults(),
                        llvm::IsaPred<AffineConstantExpr>) &&
           "expected constant extent along all dimensions.");
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return cast<AffineConstantExpr>(e).getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType(),
                        resVectorType.getScalableDims());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
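
// Worked example of the inference above (added): for a plain matmul
// contraction with indexing maps
//   lhs: (m, n, k) -> (m, k), rhs: (m, n, k) -> (k, n), res: (m, n, k) -> (m, n)
// and operands vector<3x4xf32> (lhs) and vector<4x5xf32> (rhs), the extents
// are m = 3, k = 4, n = 5, so both the accumulator and the result must be
// vector<3x5xf32>.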

LogicalResult ContractionOp::verify() {
  VectorType lhsType = getLhsType();
  VectorType rhsType = getRhsType();
  Type accType = getAccType();
  Type resType = getResultType();

  if (llvm::isa<IntegerType>(lhsType.getElementType())) {
    if (!lhsType.getElementType().isSignlessInteger())
      return emitOpError("only supports signless integer types");
  }

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify supported combining kind.
  auto vectorType = llvm::dyn_cast<VectorType>(resType);
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  // Delayed calling of IndexingMapOpInterface::verifyImpl.
  return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
}

// MaskableOpInterface methods.

/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
Type ContractionOp::getExpectedMaskType() {
  auto indexingMaps = this->getIndexingMapsArray();
  AffineMap lhsIdxMap = indexingMaps[0];
  AffineMap rhsIdxMap = indexingMaps[1];
  VectorType lhsType = this->getLhsType();
  VectorType rhsType = this->getRhsType();

  unsigned numVecDims = lhsIdxMap.getNumDims();
  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
  SmallVector<bool> maskShapeScalableDims(numVecDims, false);

  // Using the information in the indexing maps, extract the size of each
  // dimension in the vector.contract operation from the two input operands.
  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
        lhsType.getScalableDims()[dimIdx];
  }
  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
    maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
        rhsType.getScalableDims()[dimIdx];
  }

  assert(ShapedType::isStaticShape(maskShape) &&
         "Mask shape couldn't be computed");

  return VectorType::get(maskShape,
                         IntegerType::get(lhsType.getContext(), /*width=*/1),
                         maskShapeScalableDims);
}
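
// Example (added): for the matmul maps (m, n, k) -> (m, k) and
// (m, n, k) -> (k, n) with lhs vector<3x4xf32> and rhs vector<4x5xf32>, the
// iteration space is (m, n, k) = (3, 5, 4), so the expected mask type is
// vector<3x5x4xi1>.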

SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
  return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                getIteratorTypesAttrName(), getKindAttrName()};
}

static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
    if (targetExpr == map.getResult(i))
      return i;
  return -1;
}

static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
          IteratorType targetIteratorType, MLIRContext *context) {
  std::vector<std::pair<int64_t, int64_t>> dimMap;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType != targetIteratorType)
      continue;
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), context);
    int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
    int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
    if (lhsDim >= 0 && rhsDim >= 0)
      dimMap.emplace_back(lhsDim, rhsDim);
  }
  return dimMap;
}

void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0);
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0);
    assert(resVectorType != nullptr);
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}

void ContractionOp::getIterationIndexMap(
    std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
  unsigned numMaps = getIndexingMapsArray().size();
  iterationIndexMap.resize(numMaps);
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
      auto dim = cast<AffineDimExpr>(map.getResult(i));
      iterationIndexMap[index][dim.getPosition()] = i;
    }
  }
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
                   getContext());
}

std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
                   getContext());
}

std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
  SmallVector<int64_t, 4> shape;
  getIterationBounds(shape);
  return shape;
}

/// Return a fused vector::ContractionOp which represents a pattern such as:
///
/// ```mlir
///   %c0 = vector.constant 0: ...
///   %c = vector.contract %a, %b, %c0: ...
///   %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
///   %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          IRMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};

void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<CanonicalizeContractAdd<arith::AddIOp>,
              CanonicalizeContractAdd<arith::AddFOp>>(context);
}

// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}

//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//

void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                  SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source) {
  auto vectorTy = cast<VectorType>(source.getType());
  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, int64_t position) {
  build(builder, result, source, ArrayRef<int64_t>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, OpFoldResult position) {
  build(builder, result, source, ArrayRef<OpFoldResult>{position});
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<int64_t> position) {
  build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
        builder.getDenseI64ArrayAttr(position));
}

void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                              Value source, ArrayRef<OpFoldResult> position) {
  SmallVector<int64_t> staticPos;
  SmallVector<Value> dynamicPos;
  dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
  build(builder, result, source, dynamicPos,
        builder.getDenseI64ArrayAttr(staticPos));
}

LogicalResult
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
                            ExtractOp::Adaptor adaptor,
                            SmallVectorImpl<Type> &inferredReturnTypes) {
  auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
  if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
      vectorType.getRank()) {
    inferredReturnTypes.push_back(vectorType.getElementType());
  } else {
    auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
                              vectorType.getRank());
    inferredReturnTypes.push_back(VectorType::get(
        vectorType.getShape().drop_front(n), vectorType.getElementType(),
        vectorType.getScalableDims().drop_front(n)));
  }
  return success();
}

bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // Allow extracting 1-element vectors instead of scalars.
  auto isCompatible = [](TypeRange l, TypeRange r) {
    auto vectorType = llvm::dyn_cast<VectorType>(l.front());
    return vectorType && vectorType.getShape().equals({1}) &&
           vectorType.getElementType() == r.front();
  };
  if (l.size() == 1 && r.size() == 1 &&
      (isCompatible(l, r) || isCompatible(r, l)))
    return true;
  return l == r;
}

LogicalResult vector::ExtractOp::verify() {
  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
    if (resTy.getRank() == 0)
      return emitError(
          "expected a scalar instead of a 0-d vector as the result type");

  // Note: This check must come before getMixedPosition() to prevent a crash.
  auto dynamicMarkersCount =
      llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
    return emitOpError(
        "mismatch between dynamic and static positions (kDynamic marker but no "
        "corresponding dynamic position) -- this can only happen due to an "
        "incorrect fold/rewrite");
  auto position = getMixedPosition();
  if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
    return emitOpError(
        "expected position attribute of rank no greater than vector rank");
  for (auto [idx, pos] : llvm::enumerate(position)) {
    if (auto attr = dyn_cast<Attribute>(pos)) {
      int64_t constIdx = cast<IntegerAttr>(attr).getInt();
      if (!isValidPositiveIndexOrPoison(
              constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
        return emitOpError("expected position attribute #")
               << (idx + 1)
               << " to be a non-negative integer smaller than the "
                  "corresponding vector dimension or poison (-1)";
      }
    }
  }
  return success();
}

template <typename IntType>
static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
  return llvm::to_vector<4>(llvm::map_range(
      arrayAttr.getAsRange<IntegerAttr>(),
      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
}

/// Fold the result of chains of ExtractOp in place by simply concatenating the
/// positions.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
  if (!extractOp.getSource().getDefiningOp<ExtractOp>())
    return failure();

  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return failure();

  SmallVector<int64_t> globalPosition;
  ExtractOp currentOp = extractOp;
  ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
  globalPosition.append(extrPos.rbegin(), extrPos.rend());
  while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
    currentOp = nextOp;
    // TODO: Canonicalization for dynamic position not implemented yet.
    if (currentOp.hasDynamicPosition())
      return failure();
    ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
    globalPosition.append(extrPos.rbegin(), extrPos.rend());
  }
  extractOp.setOperand(0, currentOp.getSource());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  std::reverse(globalPosition.begin(), globalPosition.end());
  extractOp.setStaticPosition(globalPosition);
  return success();
}
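
// Example of the chain fold above (added): extracting from an extract result
// concatenates the positions, so
//   %a = vector.extract %v[0] : vector<3x4xf32> from vector<2x3x4xf32>
//   %b = vector.extract %a[1] : vector<4xf32> from vector<3x4xf32>
// folds %b in place to vector.extract %v[0, 1].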

namespace {
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result
/// vector. As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as
  /// equality of all entries of `a` that are >=0 with the corresponding
  /// entries of b. Comparison is on the common prefix (i.e. zip).
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
  /// // extractPosition == [1, 2, 3]
  /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
  /// // can fold to vector.extract %source[0, 3]
  /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  int64_t vectorRank;
  int64_t extractedRank;

  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the
  /// result. They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
} // namespace
1531
1532ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1533 ExtractOp e)
1534 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1535 extractedRank(extractOp.getNumIndices()) {
1536 assert(vectorRank >= extractedRank && "Extracted position overflow");
1537 sentinels.reserve(vectorRank - extractedRank);
1538 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1539 sentinels.push_back(-(i + 1));
1540 extractPosition.assign(extractOp.getStaticPosition().begin(),
1541 extractOp.getStaticPosition().end());
1542 llvm::append_range(extractPosition, sentinels);
1543}
1544
1545// Case 1. If we hit a transpose, just compose the map and iterate.
1546// Invariant: insert + transpose do not change rank, we can always compose.
1547LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1548 // TODO: Canonicalization for dynamic position not implemented yet.
1549 if (extractOp.hasDynamicPosition())
1550 return failure();
1551
1552 if (!nextTransposeOp)
1553 return failure();
1554 AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1555 nextTransposeOp.getPermutation(), extractOp.getContext()));
1556 extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1557 return success();
1558}
1559
1560// Case 2: the insert position matches extractPosition exactly, early return.
1561LogicalResult
1562ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1563 Value &res) {
1564 // TODO: Canonicalization for dynamic position not implemented yet.
1565 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1566 return failure();
1567
1568 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1569 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1570 return failure();
1571 // Case 2.a. early-exit fold.
1572 res = nextInsertOp.getValueToStore();
1573 // Case 2.b. if internal transposition is present, canFold will be false.
1574 return success(canFold());
1575}
1576
1577/// Case 3: if inserted position is a prefix of extractPosition,
1578/// extract a portion of the source of the insertion.
1579/// This method updates the internal state.
1580LogicalResult
1581ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1582 // TODO: Canonicalization for dynamic position not implemented yet.
1583 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1584 return failure();
1585
1586 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1587 if (!isContainedWithin(insertedPos, extractPosition))
1588 return failure();
1589 // Set leading dims to zero.
1590 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1591 // Drop extra leading dims.
1592 extractPosition.erase(extractPosition.begin(),
1593 extractPosition.begin() + insertedPos.size());
1594 extractedRank = extractPosition.size() - sentinels.size();
1595 // Case 3.a. early-exit fold (break and delegate to post-while path).
1596 res = nextInsertOp.getValueToStore();
1597 // Case 3.b. if internal transposition is present, canFold will be false.
1598 return success();
1599}
1600
1601/// Try to fold in place to extract(source, extractPosition) and return the
1602/// folded result. Return null if folding is not possible (e.g. due to an
1603/// internal transposition in the result).
1604Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1605 Value source) {
1606 // TODO: Canonicalization for dynamic position not implemented yet.
1607 if (extractOp.hasDynamicPosition())
1608 return Value();
1609
1610 // If we can't fold (either internal transposition, or nothing to fold), bail.
1611 bool nothingToFold = (source == extractOp.getSource());
1612 if (nothingToFold || !canFold())
1613 return Value();
1614
1615 // Otherwise, fold by updating the op inplace and return its result.
1616 OpBuilder b(extractOp.getContext());
1617 extractOp.setStaticPosition(
1618 ArrayRef(extractPosition).take_front(extractedRank));
1619 extractOp.getSourceMutable().assign(source);
1620 return extractOp.getResult();
1621}
1622
1623/// Iterate over producing insert and transpose ops until we find a fold.
1624Value ExtractFromInsertTransposeChainState::fold() {
1625 // TODO: Canonicalization for dynamic position not implemented yet.
1626 if (extractOp.hasDynamicPosition())
1627 return Value();
1628
1629 Value valueToExtractFrom = extractOp.getSource();
1630 updateStateForNextIteration(valueToExtractFrom);
1631 while (nextInsertOp || nextTransposeOp) {
1632 // Case 1. If we hit a transpose, just compose the map and iterate.
1633 // Invariant: insert + transpose do not change rank, we can always compose.
1634 if (succeeded(handleTransposeOp())) {
1635 valueToExtractFrom = nextTransposeOp.getVector();
1636 updateStateForNextIteration(valueToExtractFrom);
1637 continue;
1638 }
1639
1640 Value result;
1641 // Case 2: the positions match exactly.
1642 if (succeeded(handleInsertOpWithMatchingPos(result)))
1643 return result;
1644
1645 // Case 3: if the inserted position is a prefix of extractPosition, we can
1646 // just extract a portion of the source of the insert.
1647 if (succeeded(handleInsertOpWithPrefixPos(result)))
1648 return tryToFoldExtractOpInPlace(result);
1649
1650 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1651 // values. This is a more difficult case and we bail.
1652 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1653 if (isContainedWithin(extractPosition, insertedPos) ||
1654 intersectsWhereNonNegative(extractPosition, insertedPos))
1655 return Value();
1656
1657 // Case 5: No intersection, we forward the extract to insertOp.dest().
1658 valueToExtractFrom = nextInsertOp.getDest();
1659 updateStateForNextIteration(valueToExtractFrom);
1660 }
1661 // If after all this we can fold, go for it.
1662 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1663}
1664
1665/// Returns true if the operation has a 0-D vector type operand or result.
1666static bool hasZeroDimVectors(Operation *op) {
1667 auto hasZeroDimVectorType = [](Type type) -> bool {
1668 auto vecType = dyn_cast<VectorType>(type);
1669 return vecType && vecType.getRank() == 0;
1670 };
1671
1672 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1673 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1674}
1675
1676/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1677/// considered to be 'broadcastlike'.
1678static bool isBroadcastLike(Operation *op) {
1679 if (isa<BroadcastOp>(op))
1680 return true;
1681
1682 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1683 if (!shapeCast)
1684 return false;
1685
1686 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1687 // Checking that the destination shape has a prefix of 1s is not sufficient,
1688 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1689 // is that the source shape is a suffix of the destination shape.
1690 VectorType srcType = shapeCast.getSourceVectorType();
1691 ArrayRef<int64_t> srcShape = srcType.getShape();
1692 uint64_t srcRank = srcType.getRank();
1693 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1694 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1695}
1696
1697/// Fold extract(broadcast(X)) to either extract(X) or just X.
1698///
1699/// Example:
1700///
1701/// broadcast extract [1][2]
1702/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1703///
1704/// becomes
1705/// extract [1]
1706/// (3,4) -------------------------------------> (4)
1707///
1708///
1709/// The variable names used in this implementation correspond to the above
1710/// shapes as,
1711///
1712/// - (3, 4) is `input` shape.
1713/// - (2, 3, 4) is `broadcast` shape.
1714/// - (4) is `extract` shape.
1715///
1716/// This folding is possible when the suffix of `input` shape is the same as
1717/// `extract` shape.
1718static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1719
1720 Operation *defOp = extractOp.getSource().getDefiningOp();
1721 if (!defOp || !isBroadcastLike(defOp))
1722 return Value();
1723
1724 Value input = defOp->getOperand(0);
1725
1726 // Replace extract(broadcast(X)) with X
1727 if (extractOp.getType() == input.getType())
1728 return input;
1729
1730 // Get required types and ranks in the chain
1731 // input -> broadcast -> extract
1732 // (scalars are treated as rank-0).
1733 auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1734 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1735 unsigned inputRank = inputType ? inputType.getRank() : 0;
1736 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1737 unsigned extractRank = extractType ? extractType.getRank() : 0;
1738
1739 // Cannot do without the broadcast if overall the rank increases.
1740 if (extractRank > inputRank)
1741 return Value();
1742
1743 // The above condition guarantees that input is a vector.
1744 assert(inputType && "input must be a vector type because of previous checks");
1745 ArrayRef<int64_t> inputShape = inputType.getShape();
1746
1747 // In the case where there is a broadcast dimension in the suffix, it is not
1748 // possible to replace extract(broadcast(X)) with extract(X). Example:
1749 //
1750 // broadcast extract
1751 // (1) --------> (3,4) ------> (4)
1752 if (extractType &&
1753 extractType.getShape() != inputShape.take_back(extractRank))
1754 return Value();
1755
1756 // Replace extract(broadcast(X)) with extract(X).
1757 // First, determine the new extraction position.
1758 unsigned deltaOverall = inputRank - extractRank;
1759 unsigned deltaBroadcast = broadcastRank - inputRank;
1760 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1761 SmallVector<OpFoldResult> newPositions(deltaOverall);
1762 IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1763 for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1764 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1765 }
1766 auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1767 extractOp->setOperands(
1768 llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1769 extractOp.setStaticPosition(staticPos);
1770 return extractOp.getResult();
1771}
1772
1773/// Fold extractOp coming from ShuffleOp.
1774///
1775/// Example:
1776///
1777/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1778/// : vector<8xf32>, vector<8xf32>
1779/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1780/// ->
1781/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1782///
1783static Value foldExtractFromShuffle(ExtractOp extractOp) {
1784 // Dynamic positions are not folded as the resulting code would be more
1785 // complex than the input code.
1786 if (extractOp.hasDynamicPosition())
1787 return Value();
1788
1789 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1790 if (!shuffleOp)
1791 return Value();
1792
1793 // TODO: 0-D or multi-dimensional vectors not supported yet.
1794 if (shuffleOp.getResultVectorType().getRank() != 1)
1795 return Value();
1796
1797 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1798 auto shuffleMask = shuffleOp.getMask();
1799 int64_t extractIdx = extractOp.getStaticPosition()[0];
1800 int64_t shuffleIdx = shuffleMask[extractIdx];
1801
1802 // Find the shuffled vector to extract from based on the shuffle index.
1803 if (shuffleIdx < inputVecSize) {
1804 extractOp.setOperand(0, shuffleOp.getV1());
1805 extractOp.setStaticPosition({shuffleIdx});
1806 } else {
1807 extractOp.setOperand(0, shuffleOp.getV2());
1808 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1809 }
1810
1811 return extractOp.getResult();
1812}
1813
1814// Fold extractOp with source coming from ShapeCast op.
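// Illustrative example (SSA names are hypothetical):
//
//   %0 = vector.shape_cast %src : vector<4x2xf32> to vector<2x2x2xf32>
//   %1 = vector.extract %0[1, 1] : vector<2xf32> from vector<2x2x2xf32>
// can fold to:
//   %1 = vector.extract %src[3] : vector<2xf32> from vector<4x2xf32>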
1815static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1816 // TODO: Canonicalization for dynamic position not implemented yet.
1817 if (extractOp.hasDynamicPosition())
1818 return Value();
1819
1820 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1821 if (!shapeCastOp)
1822 return Value();
1823
1824 // Get the nth dimension size starting from lowest dimension.
1825 auto getDimReverse = [](VectorType type, int64_t n) {
1826 return type.getShape().take_back(n + 1).front();
1827 };
1828 int64_t destinationRank =
1829 llvm::isa<VectorType>(extractOp.getType())
1830 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1831 : 0;
1832 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1833 return Value();
1834 if (destinationRank > 0) {
1835 auto destinationType =
1836 llvm::cast<VectorType>(extractOp.getResult().getType());
1837 for (int64_t i = 0; i < destinationRank; i++) {
1838 // The lowest dimension of the destination must match the lowest
1839 // dimension of the shapecast op source.
1840 // TODO: This case could be supported in a canonicalization pattern.
1841 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1842 getDimReverse(destinationType, i))
1843 return Value();
1844 }
1845 }
1846 // Extract the strides associated with the extract op vector source. Then use
1847 // this to calculate a linearized position for the extract.
1848 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1849 std::reverse(extractedPos.begin(), extractedPos.end());
1850 SmallVector<int64_t, 4> strides;
1851 int64_t stride = 1;
1852 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1853 strides.push_back(stride);
1854 stride *=
1855 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1856 }
1857
1858 int64_t position = linearize(extractedPos, strides);
1859 // Then extract the strides associated to the shapeCast op vector source and
1860 // delinearize the position using those strides.
1861 SmallVector<int64_t, 4> newStrides;
1862 int64_t numDimension =
1863 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1864 stride = 1;
1865 for (int64_t i = 0; i < numDimension; i++) {
1866 newStrides.push_back(stride);
1867 stride *=
1868 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1869 }
1870 std::reverse(newStrides.begin(), newStrides.end());
1871 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1872 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1873 OpBuilder b(extractOp.getContext());
1874 extractOp.setStaticPosition(newPosition);
1875 extractOp.setOperand(0, shapeCastOp.getSource());
1876 return extractOp.getResult();
1877}
1878
1879/// Fold an ExtractOp from ExtractStridedSliceOp.
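/// For illustration (SSA names are hypothetical):
///
///   %0 = vector.extract_strided_slice %src
///        {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]}
///        : vector<4x8xf32> to vector<2x8xf32>
///   %1 = vector.extract %0[0] : vector<8xf32> from vector<2x8xf32>
/// can fold to:
///   %1 = vector.extract %src[1] : vector<8xf32> from vector<4x8xf32>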
1880static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1881 // TODO: Canonicalization for dynamic position not implemented yet.
1882 if (extractOp.hasDynamicPosition())
1883 return Value();
1884
1885 auto extractStridedSliceOp =
1886 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1887 if (!extractStridedSliceOp)
1888 return Value();
1889
1890 // 0-D vectors not supported.
1891 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1892 if (hasZeroDimVectors(extractStridedSliceOp))
1893 return Value();
1894
1895 // Return if 'extractStridedSliceOp' has non-unit strides.
1896 if (extractStridedSliceOp.hasNonUnitStrides())
1897 return Value();
1898
1899 // Trim offsets for dimensions fully extracted.
1900 auto sliceOffsets =
1901 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1902 while (!sliceOffsets.empty()) {
1903 size_t lastOffset = sliceOffsets.size() - 1;
1904 if (sliceOffsets.back() != 0 ||
1905 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1906 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1907 break;
1908 sliceOffsets.pop_back();
1909 }
1910 unsigned destinationRank = 0;
1911 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1912 destinationRank = vecType.getRank();
1913 // The dimensions of the result need to be untouched by the
1914 // extractStridedSlice op.
1915 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1916 sliceOffsets.size())
1917 return Value();
1918
1919 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1920 assert(extractedPos.size() >= sliceOffsets.size());
1921 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1922 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1923 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1924
1925 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1926 OpBuilder b(extractOp.getContext());
1927 extractOp.setStaticPosition(extractedPos);
1928 return extractOp.getResult();
1929}
1930
1931/// Fold an ExtractOp fed from a chain of InsertStridedSliceOps.
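/// Illustrative example (SSA names are hypothetical):
///
///   %0 = vector.insert_strided_slice %piece, %dst
///        {offsets = [2, 0], strides = [1, 1]}
///        : vector<2x8xf32> into vector<4x8xf32>
///   %1 = vector.extract %0[2] : vector<8xf32> from vector<4x8xf32>
/// can fold to:
///   %1 = vector.extract %piece[0] : vector<8xf32> from vector<2x8xf32>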
1932static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1933 // TODO: Canonicalization for dynamic position not implemented yet.
1934 if (extractOp.hasDynamicPosition())
1935 return Value();
1936
1937 int64_t destinationRank =
1938 llvm::isa<VectorType>(extractOp.getType())
1939 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1940 : 0;
1941 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1942 if (!insertOp)
1943 return Value();
1944
1945 // 0-D vectors not supported.
1946 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1947 if (hasZeroDimVectors(insertOp))
1948 return Value();
1949
1950 while (insertOp) {
1951 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1952 insertOp.getSourceVectorType().getRank();
1953 if (destinationRank > insertOp.getSourceVectorType().getRank())
1954 return Value();
1955 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1956 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1957
1958 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1959 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1960 }))
1961 return Value();
1962 bool disjoint = false;
1963 SmallVector<int64_t, 4> offsetDiffs;
1964 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1965 int64_t start = insertOffsets[dim];
1966 int64_t size =
1967 (dim < insertRankDiff)
1968 ? 1
1969 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1970 int64_t end = start + size;
1971 int64_t offset = extractOffsets[dim];
1972 // Check if the start of the extract offset is in the interval inserted.
1973 if (start <= offset && offset < end) {
1974 if (dim >= insertRankDiff)
1975 offsetDiffs.push_back(offset - start);
1976 continue;
1977 }
1978 disjoint = true;
1979 break;
1980 }
1981 // The extracted chunk overlaps with the inserted chunk.
1982 if (!disjoint) {
1983 // If any of the inner dimensions are only partially inserted we have a
1984 // partial overlap.
1985 int64_t srcRankDiff =
1986 insertOp.getSourceVectorType().getRank() - destinationRank;
1987 for (int64_t i = 0; i < destinationRank; i++) {
1988 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1989 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1990 insertRankDiff))
1991 return Value();
1992 }
1993 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
1994 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1995 OpBuilder b(extractOp.getContext());
1996 extractOp.setStaticPosition(offsetDiffs);
1997 return extractOp.getResult();
1998 }
1999 // If the chunk extracted is disjoint from the chunk inserted, keep
2000 // looking in the insert chain.
2001 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2002 }
2003 return Value();
2004}
2005
2006/// Try to fold the extraction of a scalar from a vector defined by
2007/// vector.from_elements. E.g.:
2008///
2009/// %0 = vector.from_elements %a, %b : vector<2xf32>
2010/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2011/// ==> fold to %a
2012static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2013 // Dynamic extractions cannot be folded.
2014 if (extractOp.hasDynamicPosition())
2015 return {};
2016
2017 // Look for extract(from_elements).
2018 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2019 if (!fromElementsOp)
2020 return {};
2021
2022 // Scalable vectors are not supported.
2023 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2024 if (vecType.isScalable())
2025 return {};
2026
2027 // Only extractions of scalars are supported.
2028 int64_t rank = vecType.getRank();
2029 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2030 if (extractOp.getType() != vecType.getElementType())
2031 return {};
2032 assert(static_cast<int64_t>(indices.size()) == rank &&
2033 "unexpected number of indices");
2034
2035 // Compute flattened/linearized index and fold to operand.
2036 int flatIndex = 0;
2037 int stride = 1;
2038 for (int i = rank - 1; i >= 0; --i) {
2039 flatIndex += indices[i] * stride;
2040 stride *= vecType.getDimSize(i);
2041 }
2042 return fromElementsOp.getElements()[flatIndex];
2043}
2044
2045/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2046/// then fold it.
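/// For example (hypothetical values):
///
///   %c1 = arith.constant 1 : index
///   %0 = vector.extract %v[%c1] : f32 from vector<4xf32>
/// folds in place to:
///   %0 = vector.extract %v[1] : f32 from vector<4xf32>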
2047template <typename OpType, typename AdaptorType>
2048static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2049 SmallVectorImpl<Value> &operands) {
2050 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2051 OperandRange dynamicPosition = op.getDynamicPosition();
2052 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2053 ArrayRef<int64_t> vectorShape;
2054 if constexpr (std::is_same_v<OpType, ExtractOp>)
2055 vectorShape = op.getSourceVectorType().getShape();
2056 else
2057 vectorShape = op.getDestVectorType().getShape();
2058
2059 // If there are no dynamic position operands, there is nothing to fold.
2060 if (!dynamicPosition.size())
2061 return {};
2062
2063 // `index` is used to iterate over the `dynamicPosition`.
2064 unsigned index = 0;
2065
2066 // `opChange` is a flag. If it is true, it means to update `op` in place.
2067 bool opChange = false;
2068 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2069 if (ShapedType::isStatic(staticPosition[i]))
2070 continue;
2071 Attribute positionAttr = dynamicPositionAttr[index];
2072 Value position = dynamicPosition[index++];
2073 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2074 int64_t value = attr.getInt();
2075 // Do not fold if the value is out of bounds (-1 signifies a poison
2076 // value rather than OOB index).
2077 if (value >= -1 && value < vectorShape[i]) {
2078 staticPosition[i] = attr.getInt();
2079 opChange = true;
2080 continue;
2081 }
2082 }
2083 operands.push_back(position);
2084 }
2085
2086 if (opChange) {
2087 op.setStaticPosition(staticPosition);
2088 op.getOperation()->setOperands(operands);
2089 // Return the original result to indicate an in-place folding happened.
2090 return op.getResult();
2091 }
2092 return {};
2093}
2094
2095/// Fold an insert or extract operation into a poison value when a poison index
2096/// is found at any dimension of the static position.
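/// For example, assuming -1 is the poison index sentinel (kPoisonIndex):
///
///   %0 = vector.extract %v[-1] : f32 from vector<4xf32>
/// folds to a ub.poison constant:
///   %0 = ub.poison : f32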
2097static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2098 ArrayRef<int64_t> staticPos,
2099 int64_t poisonVal) {
2100 if (!is_contained(staticPos, poisonVal))
2101 return {};
2102
2103 return ub::PoisonAttr::get(context);
2104}
2105
2106/// Fold a vector extract whose source is a poison value.
2107static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2108 if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2109 return srcAttr;
2110
2111 return {};
2112}
2113
2114/// Fold a vector extract extracting from a DenseElementsAttr.
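/// For illustration (SSA names are hypothetical):
///
///   %cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
///   %0 = vector.extract %cst[2] : f32 from vector<4xf32>
/// folds to the element constant 3.0.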
2115static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2116 Attribute srcAttr) {
2117 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2118 if (!denseAttr) {
2119 return {};
2120 }
2121
2122 if (denseAttr.isSplat()) {
2123 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2124 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2125 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2126 return newAttr;
2127 }
2128
2129 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2130 if (vecTy.isScalable())
2131 return {};
2132
2133 if (extractOp.hasDynamicPosition()) {
2134 return {};
2135 }
2136
2137 // Materializing subsets of a large constant array can generally lead to
2138 // an explosion in IR size because of the different combinations of subsets
2139 // that can exist. However, vector.extract is a restricted form of subset
2140 // extract where you can only extract non-overlapping (or the same) subset for
2141 // a given rank of the subset. Because of this property, the IR size can only
2142 // increase at most by `rank * size(array)` from a single constant array being
2143 // extracted by multiple extracts.
2144
2145 // Calculate the linearized position of the continuous chunk of elements to
2146 // extract.
2147 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2148 copy(extractOp.getStaticPosition(), completePositions.begin());
2149 int64_t startPos =
2150 linearize(completePositions, computeStrides(vecTy.getShape()));
2151 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2152
2153 TypedAttr newAttr;
2154 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2155 SmallVector<Attribute> elementValues(
2156 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2157 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2158 } else {
2159 newAttr = *denseValuesBegin;
2160 }
2161
2162 return newAttr;
2163}
2164
2165OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2166 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2167 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2168 // mismatch).
2169 if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2170 return getSource();
2171 if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2172 return res;
2173 // Fold `arith.constant` indices into the `vector.extract` operation.
2174 // Do not stop here as this fold may enable subsequent folds that require
2175 // constant indices.
2176 SmallVector<Value> operands = {getSource()};
2177 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2178
2179 if (auto res = foldPoisonIndexInsertExtractOp(
2180 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2181 return res;
2182 if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2183 return res;
2184 if (succeeded(foldExtractOpFromExtractChain(*this)))
2185 return getResult();
2186 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2187 return res;
2188 if (auto res = foldExtractFromBroadcast(*this))
2189 return res;
2190 if (auto res = foldExtractFromShuffle(*this))
2191 return res;
2192 if (auto res = foldExtractFromShapeCast(*this))
2193 return res;
2194 if (auto val = foldExtractFromExtractStrided(*this))
2195 return val;
2196 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2197 return val;
2198 if (auto val = foldScalarExtractFromFromElements(*this))
2199 return val;
2200
2201 return inplaceFolded;
2202}
2203
2204namespace {
2205
2206// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
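// Illustrative example (SSA names are hypothetical):
//
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x3x4xf32>
//   %e = vector.extract %b[1] : vector<3x4xf32> from vector<2x3x4xf32>
// becomes:
//   %e = vector.broadcast %src : vector<4xf32> to vector<3x4xf32>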
2207class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2208public:
2209 using Base::Base;
2210
2211 LogicalResult matchAndRewrite(ExtractOp extractOp,
2212 PatternRewriter &rewriter) const override {
2213
2214 Operation *defOp = extractOp.getSource().getDefiningOp();
2215 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2216 if (!defOp || !isBroadcastLike(defOp) || !outType)
2217 return failure();
2218
2219 Value source = defOp->getOperand(0);
2220 if (isBroadcastableTo(source.getType(), outType) !=
2221 BroadcastableToResult::Success)
2222 return failure();
2223
2224 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2225 return success();
2226 }
2227};
2228
2229// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
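// Illustrative example, assuming %c2 is a constant 2 and %d is dynamic:
//
//   %m = vector.create_mask %c2, %d : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
// becomes (since row 1 is within the constant bound 2):
//   %e = vector.create_mask %d : vector<8xi1>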
2230class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2231public:
2232 using Base::Base;
2233
2234 LogicalResult matchAndRewrite(ExtractOp extractOp,
2235 PatternRewriter &rewriter) const override {
2236 auto createMaskOp =
2237 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2238 if (!createMaskOp)
2239 return failure();
2240
2241 VectorType extractedMaskType =
2242 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2243
2244 if (!extractedMaskType)
2245 return failure();
2246
2247 auto maskOperands = createMaskOp.getOperands();
2248 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2249 VectorType maskType = createMaskOp.getVectorType();
2250
2251 bool containsUnknownDims = false;
2252 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2253
2254 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2255 dimIdx++) {
2256 int64_t pos = extractOpPos[dimIdx];
2257 Value operand = maskOperands[dimIdx];
2258 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2259 if (!constantOp) {
2260 // Bounds of this dim unknown.
2261 containsUnknownDims = true;
2262 continue;
2263 }
2264
2265 int64_t createMaskBound =
2266 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2267
2268 if (pos != ShapedType::kDynamic) {
2269 // If any position is outside the range from the `create_mask`, then the
2270 // extracted mask will be all-false.
2271 allFalse |= pos >= createMaskBound;
2272 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2273 // This dim is not all-true and since this is a dynamic index we don't
2274 // know if the extraction is within the true or false region.
2275 // Note: Zero dims have already been handled via getMaskFormat().
2276 containsUnknownDims = true;
2277 }
2278 }
2279
2280 if (allFalse) {
2281 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2282 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2283 } else if (!containsUnknownDims) {
2284 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2285 extractOp, extractedMaskType,
2286 maskOperands.drop_front(extractOpPos.size()));
2287 } else {
2288 return failure();
2289 }
2290 return success();
2291 }
2292};
2293
2294// Folds extract(shape_cast(..)) into shape_cast when the total element count
2295// does not change.
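// Illustrative example (SSA names are hypothetical):
//
//   %0 = vector.shape_cast %src : vector<2x2xf32> to vector<1x4xf32>
//   %1 = vector.extract %0[0] : vector<4xf32> from vector<1x4xf32>
// becomes:
//   %1 = vector.shape_cast %src : vector<2x2xf32> to vector<4xf32>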
2296LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2297 PatternRewriter &rewriter) {
2298 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2299 if (!castOp)
2300 return failure();
2301
2302 VectorType sourceType = castOp.getSourceVectorType();
2303 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2304 if (!targetType)
2305 return failure();
2306
2307 if (sourceType.getNumElements() != targetType.getNumElements())
2308 return failure();
2309
2310 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2311 castOp.getSource());
2312 return success();
2313}
2314
2315/// Try to canonicalize the extraction of a subvector from a vector defined by
2316/// vector.from_elements. E.g.:
2317///
2318/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2319/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2320/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2321LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2322 PatternRewriter &rewriter) {
2323 // Dynamic positions are not supported.
2324 if (extractOp.hasDynamicPosition())
2325 return failure();
2326
2327 // Scalar extracts are handled by the folder.
2328 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2329 if (!resultType)
2330 return failure();
2331
2332 // Look for extracts from a from_elements op.
2333 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2334 if (!fromElementsOp)
2335 return failure();
2336 VectorType inputType = fromElementsOp.getType();
2337
2338 // Scalable vectors are not supported.
2339 if (resultType.isScalable() || inputType.isScalable())
2340 return failure();
2341
2342 // Compute the position of first extracted element and flatten/linearize the
2343 // position.
2344 SmallVector<int64_t> firstElementPos =
2345 llvm::to_vector(extractOp.getStaticPosition());
2346 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2347 int flatIndex = 0;
2348 int stride = 1;
2349 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2350 flatIndex += firstElementPos[i] * stride;
2351 stride *= inputType.getDimSize(i);
2352 }
2353
2354 // Replace the op with a smaller from_elements op.
2355 rewriter.replaceOpWithNewOp<FromElementsOp>(
2356 extractOp, resultType,
2357 fromElementsOp.getElements().slice(flatIndex,
2358 resultType.getNumElements()));
2359 return success();
2360}
2361
2362} // namespace
2363
2364void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2365 MLIRContext *context) {
2366 results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2367 results.add(foldExtractFromShapeCastToShapeCast);
2368 results.add(foldExtractFromFromElements);
2369}
2370
2371static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2372 SmallVectorImpl<int64_t> &results) {
2373 for (auto attr : arrayAttr)
2374 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2375}
2376
2377//===----------------------------------------------------------------------===//
2378// FMAOp
2379//===----------------------------------------------------------------------===//
2380
2381std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2382 return llvm::to_vector<4>(getVectorType().getShape());
2383}
2384
2385//===----------------------------------------------------------------------===//
2386// ToElementsOp
2387//===----------------------------------------------------------------------===//
2388
2389/// Returns true if all the `operands` are defined by `defOp`.
2390/// Otherwise, returns false.
2391static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2392 if (operands.empty())
2393 return false;
2394
2395 return llvm::all_of(operands, [&](Value operand) {
2396 Operation *currentDef = operand.getDefiningOp();
2397 return currentDef == defOp;
2398 });
2399}
2400
2401/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2402/// (%e0, %e1, ...). For example:
2403///
2404/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2405/// %1:3 = vector.to_elements %0 : vector<3xf32>
2406/// user_op %1#0, %1#1, %1#2
2407///
2408/// becomes:
2409///
2410/// user_op %a, %b, %c
2411///
2412static LogicalResult
2413foldToElementsFromElements(ToElementsOp toElementsOp,
2414 SmallVectorImpl<OpFoldResult> &results) {
2415 auto fromElementsOp =
2416 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2417 if (!fromElementsOp)
2418 return failure();
2419
2420 llvm::append_range(results, fromElementsOp.getElements());
2421 return success();
2422}
2423
2424/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2425///
2426/// Example:
2427/// %b = vector.broadcast %x : i32 to vector<3xf32>
2428/// %e:3 = vector.to_elements %b : vector<3xf32>
2429/// user_op %e#0, %e#1, %e#2
2430/// becomes:
2431/// user_op %x, %x, %x
2432///
2433/// The vector source case is handled by a canonicalization pattern.
2434static LogicalResult
2435foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2436 SmallVectorImpl<OpFoldResult> &results) {
2437 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2438 if (!bcastOp)
2439 return failure();
2440 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2441 if (isa<VectorType>(bcastOp.getSource().getType()))
2442 return failure();
2443
2444 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2445
2446 Value scalar = bcastOp.getSource();
2447 results.assign(resultVecType.getNumElements(), scalar);
2448 return success();
2449}
2450
2451LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2452 SmallVectorImpl<OpFoldResult> &results) {
2453 if (succeeded(foldToElementsFromElements(*this, results)))
2454 return success();
2455 return foldToElementsOfBroadcast(*this, results);
2456}
2457
2458LogicalResult
2459ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2460 ToElementsOp::Adaptor adaptor,
2461 SmallVectorImpl<Type> &inferredReturnTypes) {
2462 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2463 Type elType = vecType.getElementType();
2464 inferredReturnTypes.append(vecType.getNumElements(), elType);
2465 return success();
2466}
2467
2468/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2469/// vector.
2470/// - Build `vector.to_elements %v` and remap each destination element to the
2471/// corresponding source element using broadcast rules (match or 1 →
2472/// replicate).
2473///
2474/// Example:
2475/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2476/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2477/// becomes:
2478/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2479/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2480/// // %src_elems#1, %src_elems#0, %src_elems#1
2481struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2482 using Base::Base;
2483
2484 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2485 PatternRewriter &rewriter) const override {
2486 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2487 if (!bcastOp)
2488 return failure();
2489
2490 // Only handle broadcasts from a vector source here.
2491 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2492 if (!srcType)
2493 return failure();
2494
2495 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2496
2497 ArrayRef<int64_t> dstShape = dstType.getShape();
2498 ArrayRef<int64_t> srcShape = srcType.getShape();
2499
2500 int64_t dstRank = dstShape.size();
2501 int64_t srcRank = srcShape.size();
2502
2503 // Create elements for the broadcast source vector.
2504 auto srcElems = vector::ToElementsOp::create(
2505 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2506
2507 int64_t dstCount = llvm::product_of(dstShape);
2508
2509 SmallVector<Value> replacements;
2510 replacements.reserve(dstCount);
2511
2512 // For each element of the destination, determine which element of the
2513 // source should be used. We walk all destination positions using a single
2514 // counter, decode it into per-dimension indices, then build the matching
2515 // source position: use the same index where sizes match, and use 0 where
2516 // the source size is 1 (replication). This mapping is needed so we can
2517 // replace each result of to_elements with the corresponding element from
2518 // the broadcast source.
2519 // Inner-dimension stretch example:
2520 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2521 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2522 // becomes:
2523 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2524 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2525 // // %src_elems#1, %src_elems#0, %src_elems#1,
2526 // // %src_elems#2, %src_elems#3, %src_elems#2,
2527 // // %src_elems#3, %src_elems#2, %src_elems#3
2528
2529 // Row-major strides for the destination shape.
2530 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2531 // Row-major strides for the source shape.
2532 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2533 SmallVector<int64_t> dstIdx(dstRank);
2534 SmallVector<int64_t> srcIdx(srcRank);
2535 for (int64_t lin = 0; lin < dstCount; ++lin) {
2536 // Convert linear destination index to per-dimension indices.
2537 dstIdx = delinearize(lin, dstStrides);
2538 for (int64_t k = 0; k < srcRank; ++k)
2539 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2540 // Convert per-dimension source indices back to a linear index.
2541 int64_t srcLin = linearize(srcIdx, srcStrides);
2542 replacements.push_back(srcElems.getResult(srcLin));
2543 }
2544
2545 rewriter.replaceOp(toElementsOp, replacements);
2546 return success();
2547 }
2548};
2549
2550void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2551 MLIRContext *context) {
2552 results.add<ToElementsOfBroadcast>(context);
2553}
2554
2555//===----------------------------------------------------------------------===//
2556// FromElementsOp
2557//===----------------------------------------------------------------------===//
2558
2559/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2560///
2561/// Case #1: Input and output vectors are the same.
2562///
2563/// %0:3 = vector.to_elements %a : vector<3xf32>
2564/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2565/// user_op %1
2566///
2567/// becomes:
2568///
2569/// user_op %a
2570///
2571static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2572 OperandRange fromElemsOperands = fromElementsOp.getElements();
2573 if (fromElemsOperands.empty())
2574 return {};
2575
2576 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2577 if (!toElementsOp)
2578 return {};
2579
2580 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2581 return {};
2582
2583 // Case #1: Input and output vectors are the same. Forward the input vector.
2584 Value toElementsInput = toElementsOp.getSource();
2585 if (fromElementsOp.getType() == toElementsInput.getType() &&
2586 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2587 return toElementsInput;
2588 }
2589
2590 // TODO: Support cases with different input and output shapes and different
2591 // number of elements.
2592
2593 return {};
2594}
2595
2596/// Fold vector.from_elements to a constant when all operands are constants.
2597/// Example:
2598/// %c1 = arith.constant 1 : i32
2599/// %c2 = arith.constant 2 : i32
2600/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2601/// =>
2602/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2603///
2604static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2605 ArrayRef<Attribute> elements) {
2606 // Check for null or poison attributes before any processing.
2607 if (llvm::any_of(elements, [](Attribute attr) {
2608 return !attr || isa<ub::PoisonAttrInterface>(attr);
2609 }))
2610 return {};
2611
2612 // DenseElementsAttr only supports int/index/float/complex types.
2613 auto destVecType = fromElementsOp.getDest().getType();
2614 auto destEltType = destVecType.getElementType();
2615 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2616 return {};
2617
2618 // Constant attributes might have a different type than the return type.
2619 // Convert them before creating the dense elements attribute.
2620 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2621 return convertNumericAttr(attr, destEltType);
2622 });
2623
2624 return DenseElementsAttr::get(destVecType, convertedElements);
2625}
2626
2627OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2628 if (auto res = foldFromElementsToElements(*this))
2629 return res;
2630 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2631 return res;
2632
2633 return {};
2634}
2635
2636/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2637/// same. Example:
2638/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2639/// =>
2640/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2641static LogicalResult
2642rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2643 PatternRewriter &rewriter) {
2644 if (!llvm::all_equal(fromElementsOp.getElements()))
2645 return failure();
2646 rewriter.replaceOpWithNewOp<BroadcastOp>(
2647 fromElementsOp, fromElementsOp.getType(),
2648 fromElementsOp.getElements().front());
2649 return success();
2650}
2651
2652/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2653/// on a single extract. Example:
2654/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2655/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2656/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2657///
2658/// becomes
2659/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2660/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2661///
2662/// The requirements for this to be valid are
2663///
2664/// i) The elements are extracted from the same vector (%source).
2665///
2666/// ii) The elements form a suffix of %source. Specifically, the number
2667/// of elements is the same as the product of the last N dimension sizes
2668/// of %source, for some N.
2669///
2670/// iii) The elements are extracted contiguously in ascending order.
2671
2672class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2673
2674 using Base::Base;
2675
2676 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2677 PatternRewriter &rewriter) const override {
2678
2679 // Handled by `rewriteFromElementsAsBroadcast`.
2680 if (fromElements.getType().getNumElements() == 1)
2681 return failure();
2682
2683 // The common source that all elements are extracted from, if one exists.
2684 TypedValue<VectorType> source;
2685 // The position of the combined extract operation, if one is created.
2686 ArrayRef<int64_t> combinedPosition;
2687 // The expected index of extraction of the current element in the loop, if
2688 // elements are extracted contiguously in ascending order.
2689 SmallVector<int64_t> expectedPosition;
2690
2691 for (auto [insertIndex, element] :
2692 llvm::enumerate(fromElements.getElements())) {
2693
2694 // Check that the element is from a vector.extract operation.
2695 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2696 if (!extractOp) {
2697 return rewriter.notifyMatchFailure(fromElements,
2698 "element not from vector.extract");
2699 }
2700
2701 // Check condition (i) by checking that all elements have the same source
2702 // as the first element.
2703 if (insertIndex == 0) {
2704 source = extractOp.getSource();
2705 } else if (extractOp.getSource() != source) {
2706 return rewriter.notifyMatchFailure(fromElements,
2707 "element from different vector");
2708 }
2709
2710 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2711 int64_t rank = position.size();
2712 assert(rank == source.getType().getRank() &&
2713 "scalar extract must have full rank position");
2714
2715 // Check condition (ii) by checking that the position that the first
2716 // element is extracted from has sufficient trailing 0s. For example, in
2717 //
2718 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2719 // [...]
2720 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2721 //
2722 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2723 // elements, which is the number of elements of %elms, so this is valid.
2724 if (insertIndex == 0) {
2725 const int64_t numElms = fromElements.getType().getNumElements();
2726 int64_t numSuffixElms = 1;
2727 int64_t index = rank;
2728 while (index > 0 && position[index - 1] == 0 &&
2729 numSuffixElms < numElms) {
2730 numSuffixElms *= source.getType().getDimSize(index - 1);
2731 --index;
2732 }
2733 if (numSuffixElms != numElms) {
2734 return rewriter.notifyMatchFailure(
2735 fromElements, "elements do not form a suffix of source");
2736 }
2737 expectedPosition = llvm::to_vector(position);
2738 combinedPosition = position.drop_back(rank - index);
2739 }
2740
2741 // Check condition (iii).
2742 else if (expectedPosition != position) {
2743 return rewriter.notifyMatchFailure(
2744 fromElements, "elements not in ascending order (static order)");
2745 }
2746 increment(expectedPosition, source.getType().getShape());
2747 }
2748
2749 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2750 fromElements.getLoc(), source, combinedPosition);
2751
2752 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2753 fromElements, fromElements.getType(), extracted);
2754
2755 return success();
2756 }
2757
2758 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2759 static void increment(MutableArrayRef<int64_t> indices,
2760 ArrayRef<int64_t> shape) {
2761 for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2762 indices[dim] += 1;
2763 if (indices[dim] < shape[dim])
2764 break;
2765 indices[dim] = 0;
2766 }
2767 }
2768};
2769
2770void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2771 MLIRContext *context) {
2772 results.add(rewriteFromElementsAsBroadcast);
2773 results.add<FromElementsToShapeCast>(context);
2774}
2775
2776//===----------------------------------------------------------------------===//
2777// BroadcastOp
2778//===----------------------------------------------------------------------===//
2779
2780void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2781 SetIntRangeFn setResultRanges) {
2782 setResultRanges(getResult(), argRanges.front());
2783}
2784
2785std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2786 return llvm::to_vector<4>(getResultVectorType().getShape());
2787}
2788
2789/// Return the dimensions of the result vector that were formerly ones in the
2790/// source vector and thus correspond to "dim-1" broadcasting.
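/// For example, srcShape = (1, 4) with dstShape = (3, 4) yields {0}, whereas
/// srcShape = (4) with dstShape = (2, 4) yields the empty set (the added
/// leading dimension is not a "dim-1" broadcast).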
2791static llvm::SetVector<int64_t>
2792computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2793 ArrayRef<int64_t> dstShape) {
2794 int64_t rankDiff = dstShape.size() - srcShape.size();
2795 int64_t dstDim = rankDiff;
2796 llvm::SetVector<int64_t> res;
2797 for (auto [s1, s2] :
2798 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2799 if (s1 != s2) {
2800 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2801 res.insert(dstDim);
2802 }
2803 ++dstDim;
2804 }
2805 return res;
2806}
2807
2808llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2809 // A scalar broadcast does not involve any unit-dim broadcasting.
2810 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2811 if (!srcVectorType)
2812 return {};
2813 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2814 getResultVectorType().getShape());
2815}
2816
2817/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2818/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2819/// This requires (and asserts) that the broadcast is free of "dim-1"
2820/// broadcasting.
2821/// Since vector.broadcast only allows expanding leading dimensions, an extra
2822/// vector.transpose may be inserted to make the broadcast possible.
2823/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2824/// the helper will assert. This means:
2825/// 1. `dstShape` must not be empty.
2826/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2827/// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2828/// must match the `value` shape.
2829Value BroadcastOp::createOrFoldBroadcastOp(
2830 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2831 const llvm::SetVector<int64_t> &broadcastedDims) {
2832 assert(!dstShape.empty() && "unexpected empty dst shape");
2833
2834 // Step 1. Well-formedness check.
2835 SmallVector<int64_t> checkShape;
2836 for (int i = 0, e = dstShape.size(); i < e; ++i) {
2837 if (broadcastedDims.contains(i))
2838 continue;
2839 checkShape.push_back(dstShape[i]);
2840 }
2841 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2842 "ill-formed broadcastedDims contains values not confined to "
2843 "destVectorShape");
2844
2845 Location loc = value.getLoc();
2846 Type elementType = getElementTypeOrSelf(value.getType());
2847 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2848 VectorType dstVectorType = VectorType::get(dstShape, elementType);
2849
2850 // Step 2. If scalar -> dstShape broadcast, just do it.
2851 if (!srcVectorType) {
2852 assert(checkShape.empty() &&
2853 "ill-formed createOrFoldBroadcastOp arguments");
2854 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2855 }
2856
2857 assert(srcVectorType.getShape().equals(checkShape) &&
2858 "ill-formed createOrFoldBroadcastOp arguments");
2859
2860 // Step 3. Since vector.broadcast only allows creating leading dims,
2861 // vector -> dstShape broadcast may require a transpose.
2862 // Traverse the dims in order and construct:
2863 // 1. The leading entries of the broadcastShape that is guaranteed to be
2864 // achievable by a simple broadcast.
2865 // 2. The induced permutation for the subsequent vector.transpose that will
2866 // bring us from `broadcastShape` back to the desired `dstShape`.
2867 // If the induced permutation is not the identity, create a vector.transpose.
2868 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2869 broadcastShape.reserve(dstShape.size());
2870 // Consider the example:
2871 // srcShape = 2x4
2872 // dstShape = 1x2x3x4x5
2873 // broadcastedDims = [0, 2, 4]
2874 //
2875 // We want to build:
2876 // broadcastShape = 1x3x5x2x4
2877 // permutation = [0, 2, 4, 1, 3]
2878 // ---V--- -----V-----
2879 // leading broadcast part src shape part
2880 //
2881 // Note that the trailing dims of broadcastShape are exactly the srcShape
2882 // by construction.
2883 // nextSrcShapeDim is used to keep track of where in the permutation the
2884 // "src shape part" occurs.
2885 int64_t nextSrcShapeDim = broadcastedDims.size();
2886 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2887 if (broadcastedDims.contains(i)) {
2888 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2889 // bring it to the head of the broadcastShape.
2890 // It will need to be permuted back from `broadcastShape.size() - 1` into
2891 // position `i`.
2892 broadcastShape.push_back(dstShape[i]);
2893 permutation[i] = broadcastShape.size() - 1;
2894 } else {
2895 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2896 // shape and needs to be permuted into position `i`.
2897 // Don't touch `broadcastShape` here, the whole srcShape will be
2898 // appended after.
2899 permutation[i] = nextSrcShapeDim++;
2900 }
2901 }
2902 // 3.c. Append the srcShape.
2903 llvm::append_range(broadcastShape, srcVectorType.getShape());
2904
2905 // Ensure there are no "dim-1" broadcasts.
2906 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2907 .empty() &&
2908 "unexpected \"dim-1\" broadcast");
2909
2910 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2911 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2912 vector::BroadcastableToResult::Success &&
2913 "must be broadcastable");
2914 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2915 // Step 4. If we find any dimension that indeed needs to be permuted,
2916 // immediately return a new vector.transpose.
2917 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2918 if (permutation[i] != i)
2919 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2920 // Otherwise return res.
2921 return res;
2922}
2923
2924BroadcastableToResult mlir::vector::isBroadcastableTo(
2925 Type srcType, VectorType dstVectorType,
2926 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2927 // Broadcast scalar to vector of the same element type.
2928 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2929 srcType == getElementTypeOrSelf(dstVectorType))
2930 return BroadcastableToResult::Success;
2931 // From now on, only vectors broadcast.
2932 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2933 if (!srcVectorType)
2934 return BroadcastableToResult::SourceTypeNotAVector;
2935
2936 int64_t srcRank = srcVectorType.getRank();
2937 int64_t dstRank = dstVectorType.getRank();
2938 if (srcRank > dstRank)
2939 return BroadcastableToResult::SourceRankHigher;
2940 // Source has an exact match or singleton value for all trailing dimensions
2941 // (all leading dimensions are simply duplicated).
2942 int64_t lead = dstRank - srcRank;
2943 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2944 // Have mismatching dims (in the sense of vector.broadcast semantics) been
2945 // encountered?
2946 bool foundMismatchingDims = false;
2947
2948 // Check fixed-width dims.
2949 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2950 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2951 if (srcDim != 1 && srcDim != dstDim)
2952 foundMismatchingDims = true;
2953
2954 // Check scalable flags.
2955 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2956 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2957 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2958 // 1 -> [N] is fine, everything else should be rejected when mixing
2959 // fixed-width and scalable dims
2960 (srcDimScalableFlag != dstDimScalableFlag &&
2961 (srcDim != 1 || srcDimScalableFlag)))
2962 foundMismatchingDims = true;
2963
2964 if (foundMismatchingDims) {
2965 if (mismatchingDims != nullptr) {
2966 mismatchingDims->first.dim = srcDim;
2967 mismatchingDims->first.isScalable = srcDimScalableFlag;
2968
2969 mismatchingDims->second.dim = dstDim;
2970 mismatchingDims->second.isScalable = dstDimScalableFlag;
2971 }
2972 return BroadcastableToResult::DimensionMismatch;
2973 }
2974 }
2975
2976 return BroadcastableToResult::Success;
2977}
2978
2979LogicalResult BroadcastOp::verify() {
2980 std::pair<VectorDim, VectorDim> mismatchingDims;
2981 BroadcastableToResult res = isBroadcastableTo(
2982 getSourceType(), getResultVectorType(), &mismatchingDims);
2983 if (res == BroadcastableToResult::Success)
2984 return success();
2985 if (res == BroadcastableToResult::SourceRankHigher)
2986 return emitOpError("source rank higher than destination rank");
2987 if (res == BroadcastableToResult::DimensionMismatch) {
2988 return emitOpError("dimension mismatch (")
2989 << (mismatchingDims.first.isScalable ? "[" : "")
2990 << mismatchingDims.first.dim
2991 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2992 << (mismatchingDims.second.isScalable ? "[" : "")
2993 << mismatchingDims.second.dim
2994 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2995 }
2996 if (res == BroadcastableToResult::SourceTypeNotAVector)
2997 return emitOpError("source type is not a vector");
2998 llvm_unreachable("unexpected vector.broadcast op error");
2999}
3000
3001// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3002// with broadcast's result type and shape_cast only adds or removes ones in the
3003// leading dimensions.
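// Illustrative example (a sketch with made-up values and types, not taken
// from a test):
//   %0 = vector.shape_cast %v : vector<4xf32> to vector<1x1x4xf32>
//   %1 = vector.broadcast %0 : vector<1x1x4xf32> to vector<8x1x1x4xf32>
// folds to
//   %1 = vector.broadcast %v : vector<4xf32> to vector<8x1x1x4xf32>
// since the shape_cast only adds leading unit dims and %v broadcasts directly
// to the result type.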
3004static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3005 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3006 if (!srcShapeCast)
3007 return failure();
3008
3009 VectorType srcType = srcShapeCast.getSourceVectorType();
3010 VectorType destType = broadcastOp.getResultVectorType();
3011 // Check type compatibility.
3012 if (vector::isBroadcastableTo(srcType, destType) !=
3013 vector::BroadcastableToResult::Success)
3014 return failure();
3015
3016 ArrayRef<int64_t> srcShape = srcType.getShape();
3017 ArrayRef<int64_t> shapecastShape =
3018 srcShapeCast.getResultVectorType().getShape();
3019 // Trailing dimensions should be the same if shape_cast only alters the
3020 // leading dimensions.
3021 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3022 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3023 shapecastShape.take_back(numTrailingDims)))
3024 return failure();
3025
3026 assert(all_of(srcShape.drop_back(numTrailingDims),
3027 [](int64_t E) { return E == 1; }) &&
3028 all_of(shapecastShape.drop_back(numTrailingDims),
3029 [](int64_t E) { return E == 1; }) &&
3030 "ill-formed shape_cast");
3031
3032 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3033 return success();
3034}
3035
3036OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3037 if (getSourceType() == getResultVectorType())
3038 return getSource();
3039 if (succeeded(foldBroadcastOfShapeCast(*this)))
3040 return getResult();
3041
3042 if (!adaptor.getSource())
3043 return {};
3044 auto vectorType = getResultVectorType();
3045 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3046 if (vectorType.getElementType() != attr.getType())
3047 return {};
3048 return DenseElementsAttr::get(vectorType, attr);
3049 }
3050 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3051 if (vectorType.getElementType() != attr.getType())
3052 return {};
3053 return DenseElementsAttr::get(vectorType, attr);
3054 }
3055 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3056 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3057 if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
3058 return ub::PoisonAttr::get(getContext());
3059 return {};
3060}
3061
3062namespace {
3063
3064// Fold broadcast1(broadcast2(x)) into broadcast1(x).
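// Illustrative example (hypothetical types):
//   %0 = vector.broadcast %arg0 : vector<4xf32> to vector<2x4xf32>
//   %1 = vector.broadcast %0 : vector<2x4xf32> to vector<3x2x4xf32>
// is rewritten to
//   %1 = vector.broadcast %arg0 : vector<4xf32> to vector<3x2x4xf32>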
3065struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3066 using Base::Base;
3067
3068 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3069 PatternRewriter &rewriter) const override {
3070 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3071 if (!srcBroadcast)
3072 return failure();
3073 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3074 broadcastOp.getResultVectorType(),
3075 srcBroadcast.getSource());
3076 return success();
3077 }
3078};
3079} // namespace
3080
3081void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3082 MLIRContext *context) {
3083 // BroadcastToShapeCast is not a default canonicalization; it is opt-in by
3084 // calling `populateCastAwayVectorLeadingOneDimPatterns`.
3085 results.add<BroadcastFolder>(context);
3086}
3087
3088//===----------------------------------------------------------------------===//
3089// ShuffleOp
3090//===----------------------------------------------------------------------===//
3091
3092LogicalResult ShuffleOp::verify() {
3093 VectorType resultType = getResultVectorType();
3094 VectorType v1Type = getV1VectorType();
3095 VectorType v2Type = getV2VectorType();
3096 // Verify ranks.
3097 int64_t resRank = resultType.getRank();
3098 int64_t v1Rank = v1Type.getRank();
3099 int64_t v2Rank = v2Type.getRank();
3100 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3101 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3102 if (!wellFormed0DCase && !wellFormedNDCase)
3103 return emitOpError("rank mismatch");
3104
3105 // Verify all but leading dimension sizes.
3106 for (int64_t r = 1; r < v1Rank; ++r) {
3107 int64_t resDim = resultType.getDimSize(r);
3108 int64_t v1Dim = v1Type.getDimSize(r);
3109 int64_t v2Dim = v2Type.getDimSize(r);
3110 if (resDim != v1Dim || v1Dim != v2Dim)
3111 return emitOpError("dimension mismatch");
3112 }
3113 // Verify mask length.
3114 ArrayRef<int64_t> mask = getMask();
3115 int64_t maskLength = mask.size();
3116 if (maskLength <= 0)
3117 return emitOpError("invalid mask length");
3118 if (maskLength != resultType.getDimSize(0))
3119 return emitOpError("mask length mismatch");
3120 // Verify all indices.
3121 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3122 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3123 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3124 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3125 return emitOpError("mask index #") << (idx + 1) << " out of range";
3126 }
3127 return success();
3128}
3129
3130LogicalResult
3131ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
3132 ShuffleOp::Adaptor adaptor,
3133 SmallVectorImpl<Type> &inferredReturnTypes) {
3134 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
3135 auto v1Rank = v1Type.getRank();
3136 // Construct resulting type: leading dimension matches mask
3137 // length, all trailing dimensions match the operands.
3138 SmallVector<int64_t, 4> shape;
3139 shape.reserve(v1Rank);
3140 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3141 // In the 0-D case there is no trailing shape to append.
3142 if (v1Rank > 0)
3143 llvm::append_range(shape, v1Type.getShape().drop_front());
3144 inferredReturnTypes.push_back(
3145 VectorType::get(shape, v1Type.getElementType()));
3146 return success();
3147}
3148
3149template <typename T>
3150static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3151 T expected = begin;
3152 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3153 return value == expected++;
3154 });
3155}
3156
3157OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3158 auto v1Type = getV1VectorType();
3159 auto v2Type = getV2VectorType();
3160
3161 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3162 "Vector shuffle does not support scalable vectors");
3163
3164 // For consistency, the 0-D shuffle returns a 1-D vector; this cannot be
3165 // expressed as a fold and must instead be canonicalized into a vector.broadcast.
3166 if (v1Type.getRank() == 0)
3167 return {};
3168
3169 // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3170 auto mask = getMask();
3171 if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3172 return getV1();
3173 // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3174 if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3175 return getV2();
3176
3177 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3178 if (!v1Attr || !v2Attr)
3179 return {};
3180
3181 // Fold shuffle poison, poison -> poison.
3182 bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
3183 bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
3184 if (isV1Poison && isV2Poison)
3185 return ub::PoisonAttr::get(getContext());
3186
3187 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
3188 // manipulation.
3189 if (v1Type.getRank() != 1)
3190 return {};
3191
3192 // Poison input attributes need special handling as they are not
3193 // DenseElementsAttr. If an index is poison, we select the first element of
3194 // the first non-poison input.
3195 SmallVector<Attribute> v1Elements, v2Elements;
3196 Attribute poisonElement;
3197 if (!isV2Poison) {
3198 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3199 if (!v2DenseAttr)
3200 return {};
3201 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3202 poisonElement = v2Elements[0];
3203 }
3204 if (!isV1Poison) {
3205 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3206 if (!v1DenseAttr)
3207 return {};
3208 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3209 poisonElement = v1Elements[0];
3210 }
3211
3212 SmallVector<Attribute> results;
3213 int64_t v1Size = v1Type.getDimSize(0);
3214 for (int64_t maskIdx : mask) {
3215 Attribute indexedElm;
3216 // TODO: Return a partial poison vector when supported by the UB dialect.
3217 if (maskIdx == ShuffleOp::kPoisonIndex) {
3218 indexedElm = poisonElement;
3219 } else {
3220 if (maskIdx < v1Size)
3221 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3222 else
3223 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3224 }
3225
3226 results.push_back(indexedElm);
3227 }
3228
3229 return DenseElementsAttr::get(getResultVectorType(), results);
3230}
3231
3232namespace {
3233
3234// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3235// to a broadcast.
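// Illustrative example (a sketch; the printed form of the ops may differ
// slightly):
//   %0 = vector.shuffle %a, %b [0] : vector<f32>, vector<f32>
// yields a vector<1xf32> and is rewritten to
//   %0 = vector.broadcast %a : vector<f32> to vector<1xf32>
// (a mask of [1] would broadcast %b instead).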
3236struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3237 using Base::Base;
3238
3239 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3240 PatternRewriter &rewriter) const override {
3241 VectorType v1VectorType = shuffleOp.getV1VectorType();
3242 ArrayRef<int64_t> mask = shuffleOp.getMask();
3243 if (v1VectorType.getRank() > 0)
3244 return failure();
3245 if (mask.size() != 1)
3246 return failure();
3247 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3248 if (mask[0] == 0)
3249 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3250 shuffleOp.getV1());
3251 else
3252 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3253 shuffleOp.getV2());
3254 return success();
3255 }
3256};
3257
3258/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3259/// vector.broadcast with a scalar operand, return the scalar value that is
3260/// splatted. Otherwise return null.
3261///
3262/// Example:
3263///
3264/// scalar_source --> vector.broadcast --> value - return scalar_source
3265static Value getScalarSplatSource(Value value) {
3266 // Block argument:
3267 Operation *defOp = value.getDefiningOp();
3268 if (!defOp)
3269 return {};
3270
3271 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3272
3273 // Not broadcast (and not splat):
3274 if (!broadcast)
3275 return {};
3276
3277 // Broadcast of a vector:
3278 if (isa<VectorType>(broadcast.getSourceType()))
3279 return {};
3280
3281 // Broadcast of a scalar:
3282 return broadcast.getSource();
3283}
3284
3285/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
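// Illustrative example (made-up names and types):
//   %s = vector.broadcast %x : f32 to vector<4xf32>
//   %0 = vector.shuffle %s, %s [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
// is rewritten to
//   %0 = vector.broadcast %x : f32 to vector<4xf32>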
3286class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3287public:
3288 using Base::Base;
3289
3290 LogicalResult matchAndRewrite(ShuffleOp op,
3291 PatternRewriter &rewriter) const override {
3292 Value splat = getScalarSplatSource(op.getV1());
3293 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3294 return failure();
3295
3296 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3297 return success();
3298 }
3299};
3300
3301/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3302/// vector.interleave.
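// Illustrative example (a sketch; the exact printed form of vector.interleave
// may differ):
//   %0 = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
//       : vector<4xi32>, vector<4xi32>
// is rewritten to
//   %0 = vector.interleave %a, %b : vector<4xi32> -> vector<8xi32>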
3303class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3304public:
3305 using Base::Base;
3306
3307 LogicalResult matchAndRewrite(ShuffleOp op,
3308 PatternRewriter &rewriter) const override {
3309 VectorType resultType = op.getResultVectorType();
3310 if (resultType.isScalable())
3311 return rewriter.notifyMatchFailure(
3312 op, "ShuffleOp can't represent a scalable interleave");
3313
3314 if (resultType.getRank() != 1)
3315 return rewriter.notifyMatchFailure(
3316 op, "ShuffleOp can't represent an n-D interleave");
3317
3318 VectorType sourceType = op.getV1VectorType();
3319 if (sourceType != op.getV2VectorType() ||
3320 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3321 return rewriter.notifyMatchFailure(
3322 op, "ShuffleOp types don't match an interleave");
3323 }
3324
3325 ArrayRef<int64_t> shuffleMask = op.getMask();
3326 int64_t resultVectorSize = resultType.getNumElements();
3327 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3328 int64_t maskValueA = shuffleMask[i * 2];
3329 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3330 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3331 return rewriter.notifyMatchFailure(op,
3332 "ShuffleOp mask not interleaving");
3333 }
3334
3335 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3336 return success();
3337 }
3338};
3339
3340} // namespace
3341
3342void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3343 MLIRContext *context) {
3344 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3345 context);
3346}
3347
3348//===----------------------------------------------------------------------===//
3349// InsertOp
3350//===----------------------------------------------------------------------===//
3351
3352void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3353 SetIntRangeFn setResultRanges) {
3354 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3355}
3356
3357void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3358 Value source, Value dest) {
3359 auto vectorTy = cast<VectorType>(dest.getType());
3360 build(builder, result, source, dest,
3361 SmallVector<int64_t>(vectorTy.getRank(), 0));
3362}
3363
3364void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3365 Value source, Value dest, int64_t position) {
3366 build(builder, result, source, dest, ArrayRef<int64_t>{position});
3367}
3368
3369void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3370 Value source, Value dest, OpFoldResult position) {
3371 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3372}
3373
3374void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3375 Value source, Value dest,
3376 ArrayRef<int64_t> position) {
3377 SmallVector<OpFoldResult> posVals;
3378 posVals.reserve(position.size());
3379 llvm::transform(position, std::back_inserter(posVals),
3380 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3381 build(builder, result, source, dest, posVals);
3382}
3383
3384void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3385 Value source, Value dest,
3386 ArrayRef<OpFoldResult> position) {
3387 SmallVector<int64_t> staticPos;
3388 SmallVector<Value> dynamicPos;
3389 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3390 build(builder, result, source, dest, dynamicPos,
3391 builder.getDenseI64ArrayAttr(staticPos));
3392}
3393
3394LogicalResult InsertOp::verify() {
3395 if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3396 if (srcTy.getRank() == 0)
3397 return emitError(
3398 "expected a scalar instead of a 0-d vector as the source operand");
3399
3400 SmallVector<OpFoldResult> position = getMixedPosition();
3401 auto destVectorType = getDestVectorType();
3402 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3403 return emitOpError(
3404 "expected position attribute of rank no greater than dest vector rank");
3405 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3406 if (srcVectorType &&
3407 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3408 static_cast<unsigned>(destVectorType.getRank())))
3409 return emitOpError("expected position attribute rank + source rank to "
3410 "match dest vector rank");
3411 if (!srcVectorType &&
3412 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3413 return emitOpError(
3414 "expected position attribute rank to match the dest vector rank");
3415 for (auto [idx, pos] : llvm::enumerate(position)) {
3416 if (auto attr = dyn_cast<Attribute>(pos)) {
3417 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3418 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3419 destVectorType.getDimSize(idx))) {
3420 return emitOpError("expected position attribute #")
3421 << (idx + 1)
3422 << " to be a non-negative integer smaller than the "
3423 "corresponding "
3424 "dest vector dimension";
3425 }
3426 }
3427 }
3428 return success();
3429}
3430
3431 // Calculate the linearized position of the contiguous chunk of elements to
3432// insert, based on the shape of the value to insert and the positions to insert
3433// at.
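// For example (hypothetical values): with destTy = vector<2x3xf32> and
// positions = [1], the completed position is [1, 0], the strides are [3, 1],
// and the returned linearized start position is 1 * 3 + 0 * 1 = 3.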
3434static int64_t calculateInsertPosition(VectorType destTy,
3435 ArrayRef<int64_t> positions) {
3436 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3437 assert(positions.size() <= completePositions.size() &&
3438 "positions size must be less than or equal to destTy rank");
3439 copy(positions, completePositions.begin());
3440 return linearize(completePositions, computeStrides(destTy.getShape()));
3441}
3442
3443namespace {
3444
3445 // If insertOp is only inserting unit dimensions, it can be transformed to a
3446// broadcast.
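// Illustrative example (made-up types):
//   %0 = vector.insert %v, %dst[0, 0] : vector<4xf32> into vector<1x1x4xf32>
// inserts as many elements as the destination holds, so it is rewritten to
//   %0 = vector.broadcast %v : vector<4xf32> to vector<1x1x4xf32>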
3447class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3448public:
3449 using Base::Base;
3450
3451 LogicalResult matchAndRewrite(InsertOp insertOp,
3452 PatternRewriter &rewriter) const override {
3453 auto srcVecType =
3454 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3455 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3456 srcVecType.getNumElements())
3457 return failure();
3458 rewriter.replaceOpWithNewOp<BroadcastOp>(
3459 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3460 return success();
3461 }
3462};
3463
3464/// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
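// Illustrative example (made-up names and types):
//   %d = vector.broadcast %x : f32 to vector<2x4xf32>
//   %v = vector.broadcast %x : f32 to vector<4xf32>
//   %0 = vector.insert %v, %d[0] : vector<4xf32> into vector<2x4xf32>
// is rewritten to
//   %0 = vector.broadcast %x : f32 to vector<2x4xf32>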
3465class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3466public:
3467 using Base::Base;
3468
3469 LogicalResult matchAndRewrite(InsertOp op,
3470 PatternRewriter &rewriter) const override {
3471
3472 Value splat = getScalarSplatSource(op.getValueToStore());
3473 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3474 return failure();
3475
3476 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3477 return success();
3478 }
3479};
3480
3481/// Pattern to optimize a chain of insertions.
3482///
3483/// This pattern identifies chains of vector.insert operations that:
3484/// 1. Only insert values at static positions.
3485/// 2. Completely initialize all elements in the resulting vector.
3486/// 3. All intermediate insert operations have only one use.
3487///
3488/// When these conditions are met, the entire chain can be replaced with a
3489/// single vector.from_elements operation.
3490///
3491/// To keep this pattern simple, and avoid spending too much time on matching
3492/// fragmented insert chains, this pattern only considers the last insert op in
3493/// the chain.
3494///
3495/// Example transformation:
3496/// %poison = ub.poison : vector<2xi32>
3497/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3498/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3499/// ->
3500/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3501class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3502public:
3503 using Base::Base;
3504 LogicalResult matchAndRewrite(InsertOp op,
3505 PatternRewriter &rewriter) const override {
3506
3507 VectorType destTy = op.getDestVectorType();
3508 if (destTy.isScalable())
3509 return failure();
3510 // Ensure this is the trailing vector.insert op in a chain of inserts.
3511 for (Operation *user : op.getResult().getUsers())
3512 if (auto insertOp = dyn_cast<InsertOp>(user))
3513 if (insertOp.getDest() == op.getResult())
3514 return failure();
3515
3516 InsertOp currentOp = op;
3517 SmallVector<InsertOp> chainInsertOps;
3518 while (currentOp) {
3519 // Check cond 1: Dynamic position is not supported.
3520 if (currentOp.hasDynamicPosition())
3521 return failure();
3522
3523 chainInsertOps.push_back(currentOp);
3524 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3525 // Check cond 3: Intermediate inserts have only one use to avoid an
3526 // explosion of vectors.
3527 if (currentOp && !currentOp->hasOneUse())
3528 return failure();
3529 }
3530
3531 int64_t vectorSize = destTy.getNumElements();
3532 int64_t initializedCount = 0;
3533 SmallVector<bool> initializedDestIdxs(vectorSize, false);
3534 SmallVector<int64_t> pendingInsertPos;
3535 SmallVector<int64_t> pendingInsertSize;
3536 SmallVector<Value> pendingInsertValues;
3537
3538 for (auto insertOp : chainInsertOps) {
3539 // This pattern cannot handle poison indices.
3540 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3541 return failure();
3542
3543 // Calculate the linearized position for inserting elements.
3544 int64_t insertBeginPosition =
3545 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3546
3547 // The valueToStore operand may be a vector or a scalar. Need to handle
3548 // both cases.
3549 int64_t insertSize = 1;
3550 if (auto srcVectorType =
3551 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3552 insertSize = srcVectorType.getNumElements();
3553
3554 assert(insertBeginPosition + insertSize <= vectorSize &&
3555 "insert would overflow the vector");
3556
3557 for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3558 insertBeginPosition + insertSize)) {
3559 if (initializedDestIdxs[index])
3560 continue;
3561 initializedDestIdxs[index] = true;
3562 ++initializedCount;
3563 }
3564
3565 // Defer the creation of ops before we can make sure the pattern can
3566 // succeed.
3567 pendingInsertPos.push_back(insertBeginPosition);
3568 pendingInsertSize.push_back(insertSize);
3569 pendingInsertValues.push_back(insertOp.getValueToStore());
3570
3571 if (initializedCount == vectorSize)
3572 break;
3573 }
3574
3575 // Check cond 2: all positions must be initialized.
3576 if (initializedCount != vectorSize)
3577 return failure();
3578
3579 SmallVector<Value> elements(vectorSize);
3580 for (auto [insertBeginPosition, insertSize, valueToStore] :
3581 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3582 pendingInsertValues))) {
3583 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3584
3585 if (!srcVectorType) {
3586 elements[insertBeginPosition] = valueToStore;
3587 continue;
3588 }
3589
3590 SmallVector<Type> elementToInsertTypes(insertSize,
3591 srcVectorType.getElementType());
3592 // Get all elements from the vector in row-major order.
3593 auto elementsToInsert = vector::ToElementsOp::create(
3594 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3595 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3596 elements[insertBeginPosition + linearIdx] =
3597 elementsToInsert.getResult(linearIdx);
3598 }
3599 }
3600
3601 rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3602 return success();
3603 }
3604};
3605
3606} // namespace
3607
3608static Attribute
3609foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3610 Attribute dstAttr,
3611 int64_t maxVectorSizeFoldThreshold) {
3612 if (insertOp.hasDynamicPosition())
3613 return {};
3614
3615 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3616 if (!denseDst)
3617 return {};
3618
3619 if (!srcAttr) {
3620 return {};
3621 }
3622
3623 VectorType destTy = insertOp.getDestVectorType();
3624 if (destTy.isScalable())
3625 return {};
3626
3627 // Make sure we do not create too many large constants.
3628 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3629 !insertOp->hasOneUse())
3630 return {};
3631
3632 // Calculate the linearized position for inserting elements.
3633 int64_t insertBeginPosition =
3634 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3635 SmallVector<Attribute> insertedValues;
3636 Type destEltType = destTy.getElementType();
3637
3638 /// Converts attribute to the expected type if there's
3639 /// a mismatch.
3640 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3641 for (auto value : denseSource.getValues<Attribute>())
3642 insertedValues.push_back(convertNumericAttr(value, destEltType));
3643 } else {
3644 insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3645 }
3646
3647 auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3648 copy(insertedValues, allValues.begin() + insertBeginPosition);
3649 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3650
3651 return newAttr;
3652}
3653
3654/// Folder to replace the `dest` operand of the insert op with the root dest of
3655/// the insert op use chain.
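/// Illustrative example (made-up values): in
///   %0 = vector.insert %a, %dst[0] : f32 into vector<4xf32>
///   %1 = vector.insert %b, %0[0] : f32 into vector<4xf32>
/// the second insert overwrites the position written by the first, so its
/// `dest` can be replaced with %dst directly.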
3656static Value foldInsertUseChain(InsertOp insertOp) {
3657 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3658 if (!destInsert)
3659 return {};
3660
3661 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3662 return {};
3663
3664 insertOp.setOperand(1, destInsert.getDest());
3665 return insertOp.getResult();
3666}
3667
3668void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3669 MLIRContext *context) {
3670 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3671 InsertChainFullyInitialized>(context);
3672}
3673
3674OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3675 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3676 // unless the source vector constant has a single use.
3677 constexpr int64_t vectorSizeFoldThreshold = 256;
3678 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3679 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3680 // (type mismatch).
3681 if (getNumIndices() == 0 && getValueToStoreType() == getType())
3682 return getValueToStore();
3683 // Fold `arith.constant` indices into the `vector.insert` operation.
3684 // Do not stop here as this fold may enable subsequent folds that require
3685 // constant indices.
3686 SmallVector<Value> operands = {getValueToStore(), getDest()};
3687 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3688
3689 if (auto res = foldInsertUseChain(*this))
3690 return res;
3691 if (auto res = foldPoisonIndexInsertExtractOp(
3692 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3693 return res;
3694 if (auto res = foldDenseElementsAttrDestInsertOp(
3695 *this, adaptor.getValueToStore(), adaptor.getDest(),
3696 vectorSizeFoldThreshold)) {
3697 return res;
3698 }
3699
3700 return inplaceFolded;
3701}
3702
3703//===----------------------------------------------------------------------===//
3704// InsertStridedSliceOp
3705//===----------------------------------------------------------------------===//
3706
3707void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3708 Value source, Value dest,
3709 ArrayRef<int64_t> offsets,
3710 ArrayRef<int64_t> strides) {
3711 result.addOperands({source, dest});
3712 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3713 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3714 result.addTypes(dest.getType());
3715 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3716 offsetsAttr);
3717 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3718 stridesAttr);
3719}
3720
3721// TODO: Should be moved to Tablegen ConfinedAttr attributes.
3722template <typename OpType>
3723static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3724 ArrayAttr arrayAttr,
3725 ArrayRef<int64_t> shape,
3726 StringRef attrName) {
3727 if (arrayAttr.size() > shape.size())
3728 return op.emitOpError("expected ")
3729 << attrName << " attribute of rank no greater than vector rank";
3730 return success();
3731}
3732
3733 // Returns success if all integers in `arrayAttr` are in the half-open [min, max)
3734// interval. If `halfOpen` is true then the admissible interval is [min, max).
3735// Otherwise, the admissible interval is [min, max].
3736template <typename OpType>
3737static LogicalResult
3738isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3739 int64_t max, StringRef attrName,
3740 bool halfOpen = true) {
3741 for (auto attr : arrayAttr) {
3742 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3743 auto upper = max;
3744 if (!halfOpen)
3745 upper += 1;
3746 if (val < min || val >= upper)
3747 return op.emitOpError("expected ") << attrName << " to be confined to ["
3748 << min << ", " << upper << ")";
3749 }
3750 return success();
3751}
3752
3753 // Returns success if all integers in `arrayAttr` are in the half-open [min, max)
3754// interval. If `halfOpen` is true then the admissible interval is [min, max).
3755// Otherwise, the admissible interval is [min, max].
3756template <typename OpType>
3757static LogicalResult
3758isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3759 ArrayRef<int64_t> shape, StringRef attrName,
3760 bool halfOpen = true, int64_t min = 0) {
3761 for (auto [index, attrDimPair] :
3762 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3763 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3764 int64_t max = std::get<1>(attrDimPair);
3765 if (!halfOpen)
3766 max += 1;
3767 if (val < min || val >= max)
3768 return op.emitOpError("expected ")
3769 << attrName << " dimension " << index << " to be confined to ["
3770 << min << ", " << max << ")";
3771 }
3772 return success();
3773}
3774
3775 // Returns success if, for all indices i = 0..shape.size()-1, val is in the
3776 // [min, max) interval:
3777// val = `arrayAttr1[i]` + `arrayAttr2[i]`,
3778// If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
3779// the admissible interval is [min, max].
3780template <typename OpType>
3781static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3782 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3783 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3784 bool halfOpen = true, int64_t min = 1) {
3785 assert(arrayAttr1.size() <= shape.size());
3786 assert(arrayAttr2.size() <= shape.size());
3787 for (auto [index, it] :
3788 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3789 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3790 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3791 int64_t max = std::get<2>(it);
3792 if (!halfOpen)
3793 max += 1;
3794 if (val1 + val2 < 0 || val1 + val2 >= max)
3795 return op.emitOpError("expected sum(")
3796 << attrName1 << ", " << attrName2 << ") dimension " << index
3797 << " to be confined to [" << min << ", " << max << ")";
3798 }
3799 return success();
3800}
3801
3802static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3803 MLIRContext *context) {
3804 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3805 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3806 });
3807 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3808}
3809
3810LogicalResult InsertStridedSliceOp::verify() {
3811 auto sourceVectorType = getSourceVectorType();
3812 auto destVectorType = getDestVectorType();
3813 auto offsets = getOffsetsAttr();
3814 auto strides = getStridesAttr();
3815 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3816 return emitOpError(
3817 "expected offsets of same size as destination vector rank");
3818 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3819 return emitOpError("expected strides of same size as source vector rank");
3820 if (sourceVectorType.getRank() > destVectorType.getRank())
3821 return emitOpError(
3822 "expected source rank to be no greater than destination rank");
3823
3824 auto sourceShape = sourceVectorType.getShape();
3825 auto destShape = destVectorType.getShape();
3826 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3827 destShape.size() - sourceShape.size(), 0);
3828 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3829 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3830 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3831 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3832 offName)) ||
3833 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3834 /*max=*/1, stridesName,
3835 /*halfOpen=*/false)) ||
3836 failed(isSumOfIntegerArrayAttrConfinedToShape(
3837 *this, offsets,
3838 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3839 offName, "source vector shape",
3840 /*halfOpen=*/false, /*min=*/1)))
3841 return failure();
3842
3843 unsigned rankDiff = destShape.size() - sourceShape.size();
3844 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3845 if (sourceVectorType.getScalableDims()[idx] !=
3846 destVectorType.getScalableDims()[idx + rankDiff]) {
3847 return emitOpError("mismatching scalable flags (at source vector idx=")
3848 << idx << ")";
3849 }
3850 if (sourceVectorType.getScalableDims()[idx]) {
3851 auto sourceSize = sourceShape[idx];
3852 auto destSize = destShape[idx + rankDiff];
3853 if (sourceSize != destSize) {
3854 return emitOpError("expected size at idx=")
3855 << idx
3856 << (" to match the corresponding base size from the input "
3857 "vector (")
3858 << sourceSize << (" vs ") << destSize << (")");
3859 }
3860 }
3861 }
3862
3863 return success();
3864}
3865
3866namespace {
3867/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
3868class FoldInsertStridedSliceSplat final
3869 : public OpRewritePattern<InsertStridedSliceOp> {
3870public:
3871 using Base::Base;
3872
3873 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3874 PatternRewriter &rewriter) const override {
3875
3876 auto dst = insertStridedSliceOp.getDest();
3877 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3878 if (!splat || getScalarSplatSource(dst) != splat)
3879 return failure();
3880
3881 rewriter.replaceOp(insertStridedSliceOp, dst);
3882 return success();
3883 }
3884};
3885
3886/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3887/// to dst.
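/// Illustrative example (made-up types and attribute values):
///   %e = vector.extract_strided_slice %dst
///       {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
///       : vector<4x8xf32> to vector<2x4xf32>
///   %0 = vector.insert_strided_slice %e, %dst
///       {offsets = [0, 2], strides = [1, 1]}
///       : vector<2x4xf32> into vector<4x8xf32>
/// is rewritten to %dst, since the slice is written back to the positions it
/// was read from.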
3888class FoldInsertStridedSliceOfExtract final
3889 : public OpRewritePattern<InsertStridedSliceOp> {
3890public:
3891 using Base::Base;
3892
3893 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3894 PatternRewriter &rewriter) const override {
3895 auto extractStridedSliceOp =
3896 insertStridedSliceOp.getValueToStore()
3897 .getDefiningOp<vector::ExtractStridedSliceOp>();
3898
3899 if (!extractStridedSliceOp)
3900 return failure();
3901
3902 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3903 return failure();
3904
3905 // Check if they have the same strides and offsets.
3906 if (extractStridedSliceOp.getStrides() !=
3907 insertStridedSliceOp.getStrides() ||
3908 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3909 return failure();
3910
3911 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3912 return success();
3913 }
3914};
3915
3916// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3917// ConstantOp.
3918class InsertStridedSliceConstantFolder final
3919 : public OpRewritePattern<InsertStridedSliceOp> {
3920public:
3921 using Base::Base;
3922
3923 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3924 // unless the source vector constant has a single use.
3925 static constexpr int64_t vectorSizeFoldThreshold = 256;
3926
3927 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3928 PatternRewriter &rewriter) const override {
3929 // Return if 'InsertOp' operand is not defined by a compatible vector
3930 // ConstantOp.
3931 TypedValue<VectorType> destVector = op.getDest();
3932 Attribute vectorDestCst;
3933 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3934 return failure();
3935
3936 VectorType destTy = destVector.getType();
3937 if (destTy.isScalable())
3938 return failure();
3939
3940 // Make sure we do not create too many large constants.
3941 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3942 !destVector.hasOneUse())
3943 return failure();
3944
3945 TypedValue<VectorType> sourceValue = op.getValueToStore();
3946 Attribute sourceCst;
3947 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3948 return failure();
3949
3950 // TODO: Support poison.
3951 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3952 return failure();
3953
3954 // TODO: Handle non-unit strides when they become available.
3955 if (op.hasNonUnitStrides())
3956 return failure();
3957
3958 VectorType sliceVecTy = sourceValue.getType();
3959 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3960 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3961 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3962 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3963
3964 // Calculate the destination element indices by enumerating all slice
3965 // positions within the destination and linearizing them. The enumeration
3966 // order is lexicographic which yields a sequence of monotonically
3967 // increasing linearized position indices.
3968 // Because the destination may have higher dimensionality than the slice,
3969 // we keep track of two overlapping sets of positions and offsets.
3970 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3971 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3972 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3973 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3974 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3975 MutableArrayRef<int64_t> currSlicePosition(
3976 currDestPosition.begin() + rankDifference, currDestPosition.end());
3977 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3978 offsets.end());
3979 do {
3980 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3981 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3982 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3983 "Invalid slice element");
3984 newValues[linearizedPosition] = *sliceValuesIt;
3985 ++sliceValuesIt;
3986 } while (succeeded(
3987 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3988
3989 auto newAttr = DenseElementsAttr::get(destTy, newValues);
3990 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3991 return success();
3992 }
3993};
3994
3995} // namespace
3996
3997void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3998 RewritePatternSet &results, MLIRContext *context) {
3999 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4000 InsertStridedSliceConstantFolder>(context);
4001}
4002
4003OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4004 if (getSourceVectorType() == getDestVectorType())
4005 return getValueToStore();
4006 return {};
4007}
4008
4009//===----------------------------------------------------------------------===//
4010// OuterProductOp
4011//===----------------------------------------------------------------------===//
4012
4013/// Build an op without mask, use the type of `acc` as the return type.
4014void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4015 Value lhs, Value rhs, Value acc) {
4016 result.addOperands({lhs, rhs, acc});
4017 result.addTypes(acc.getType());
4018}
4019
4020void OuterProductOp::print(OpAsmPrinter &p) {
4021 p << " " << getLhs() << ", " << getRhs();
4022 if (getAcc()) {
4023 p << ", " << getAcc();
4024 p.printOptionalAttrDict((*this)->getAttrs());
4025 }
4026 p << " : " << getLhs().getType() << ", " << getRhs().getType();
4027}
4028
4029ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4030 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4031 Type tLHS, tRHS;
4032 if (parser.parseOperandList(operandsInfo) ||
4033 parser.parseOptionalAttrDict(result.attributes) ||
4034 parser.parseColonType(tLHS) || parser.parseComma() ||
4035 parser.parseType(tRHS))
4036 return failure();
4037 if (operandsInfo.size() < 2)
4038 return parser.emitError(parser.getNameLoc(),
4039 "expected at least 2 operands");
4040 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4041 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4042 if (!vLHS)
4043 return parser.emitError(parser.getNameLoc(),
4044 "expected vector type for operand #1");
4045
4046 VectorType resType;
4047 if (vRHS) {
4048 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4049 vRHS.getScalableDims()[0]};
4050 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4051 vLHS.getElementType(), scalableDimsRes);
4052 } else {
4053 // Scalar RHS operand
4054 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4055 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4056 scalableDimsRes);
4057 }
4058
4059 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4060 result.attributes.append(
4061 OuterProductOp::getKindAttrName(result.name),
4062 CombiningKindAttr::get(result.getContext(),
4063 OuterProductOp::getDefaultKind()));
4064 }
4065
4066 return failure(
4067 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4068 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4069 (operandsInfo.size() > 2 &&
4070 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4071 parser.addTypeToList(resType, result.types));
4072}
4073
4074LogicalResult OuterProductOp::verify() {
4075 Type tRHS = getOperandTypeRHS();
4076 VectorType vLHS = getOperandVectorTypeLHS(),
4077 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4078 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4079
4080 if (vLHS.getRank() != 1)
4081 return emitOpError("expected 1-d vector for operand #1");
4082
4083 if (vRHS) {
4084 // Proper OUTER operation.
4085 if (vRHS.getRank() != 1)
4086 return emitOpError("expected 1-d vector for operand #2");
4087 if (vRES.getRank() != 2)
4088 return emitOpError("expected 2-d vector result");
4089 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4090 return emitOpError("expected #1 operand dim to match result dim #1");
4091 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4092 return emitOpError("expected #2 operand dim to match result dim #2");
4093 if (vLHS.isScalable() && !vRHS.isScalable()) {
4094 // This restriction reflects what's currently supported in terms of
4095 // scalable vectors. However, we could relax this if there's a use case.
4096 return emitOpError(
4097 "expected either both or only #2 operand dim to be scalable");
4098 }
4099 } else {
4100 // An AXPY operation.
4101 if (vRES.getRank() != 1)
4102 return emitOpError("expected 1-d vector result");
4103 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4104 return emitOpError("expected #1 operand dim to match result dim #1");
4105 }
4106
4107 if (vACC && vACC != vRES)
4108 return emitOpError("expected operand #3 of same type as result type");
4109
4110 // Verify supported combining kind.
4111 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4112 return emitOpError("unsupported outerproduct type");
4113
4114 return success();
4115}
4116
4117// MaskableOpInterface methods.
4118
4119/// Returns the mask type expected by this operation. Mostly used for
4120/// verification purposes. It requires the operation to be vectorized.
4121Type OuterProductOp::getExpectedMaskType() {
4122 auto vecType = this->getResultVectorType();
4123 return VectorType::get(vecType.getShape(),
4124 IntegerType::get(vecType.getContext(), /*width=*/1),
4125 vecType.getScalableDims());
4126}
4127
4128//===----------------------------------------------------------------------===//
4129// ExtractStridedSliceOp
4130//===----------------------------------------------------------------------===//
4131
4132// Inference works as follows:
4133 // 1. Take the leading dim sizes from 'sizes' (one per entry in 'offsets').
4134 // 2. Take the remaining dim sizes from 'vectorType'.
4135// Scalable flags are inherited from 'vectorType'.
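// For example (hypothetical types): slicing vector<4x8x16xf32> with
// offsets = [0, 2], sizes = [2, 4], strides = [1, 1] yields vector<2x4x16xf32>;
// the first two dims come from 'sizes' and the last dim is carried over from
// the source type.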
4136static Type inferStridedSliceOpResultType(VectorType vectorType,
4137 ArrayAttr offsets, ArrayAttr sizes,
4138 ArrayAttr strides) {
4139 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4140 SmallVector<int64_t, 4> shape;
4141 shape.reserve(vectorType.getRank());
4142 unsigned idx = 0;
4143 for (unsigned e = offsets.size(); idx < e; ++idx)
4144 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4145 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4146 shape.push_back(vectorType.getShape()[idx]);
4147
4148 return VectorType::get(shape, vectorType.getElementType(),
4149 vectorType.getScalableDims());
4150}
4151
4152void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4153 Value source, ArrayRef<int64_t> offsets,
4154 ArrayRef<int64_t> sizes,
4155 ArrayRef<int64_t> strides) {
4156 result.addOperands(source);
4157 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4158 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4159 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4160 result.addTypes(
4161 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4162 offsetsAttr, sizesAttr, stridesAttr));
4163 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4164 offsetsAttr);
4165 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4166 sizesAttr);
4167 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4168 stridesAttr);
4169}
4170
4171LogicalResult ExtractStridedSliceOp::verify() {
4172 auto type = getSourceVectorType();
4173 auto offsets = getOffsetsAttr();
4174 auto sizes = getSizesAttr();
4175 auto strides = getStridesAttr();
4176 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4177 return emitOpError(
4178 "expected offsets, sizes and strides attributes of same size");
4179
4180 auto shape = type.getShape();
4181 auto offName = getOffsetsAttrName();
4182 auto sizesName = getSizesAttrName();
4183 auto stridesName = getStridesAttrName();
4184 if (failed(
4185 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4186 failed(
4187 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4188 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4189 stridesName)) ||
4190 failed(
4191 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4192 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4193 /*halfOpen=*/false,
4194 /*min=*/1)) ||
4195 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4196 /*max=*/1, stridesName,
4197 /*halfOpen=*/false)) ||
4198 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4199 shape, offName, sizesName,
4200 /*halfOpen=*/false)))
4201 return failure();
4202
4203 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4204 offsets, sizes, strides);
4205 if (getResult().getType() != resultType)
4206 return emitOpError("expected result type to be ") << resultType;
4207
4208 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4209 if (type.getScalableDims()[idx]) {
4210 auto inputDim = type.getShape()[idx];
4211 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4212 if (inputDim != inputSize)
4213 return emitOpError("expected size at idx=")
4214 << idx
4215 << (" to match the corresponding base size from the input "
4216 "vector (")
4217 << inputSize << (" vs ") << inputDim << (")");
4218 }
4219 }
4220
4221 return success();
4222}
4223
4224 // When the source of an ExtractStridedSliceOp comes from a chain of
4225 // InsertStridedSliceOps, try to use the source of one of those inserts when we
4226 // can detect that the extracted vector is a subset of one of the inserted vectors.
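// Illustrative example (made-up types and attribute values):
//   %0 = vector.insert_strided_slice %src, %dst
//       {offsets = [2], strides = [1]} : vector<4xf32> into vector<16xf32>
//   %1 = vector.extract_strided_slice %0
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<16xf32> to vector<4xf32>
// Here the extracted chunk lies entirely within the inserted chunk, so the
// extract is rewritten to read from %src (at offset 0), which then folds away.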
4227static LogicalResult
4228foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4229 // Helper to extract integer out of ArrayAttr.
4230 auto getElement = [](ArrayAttr array, int idx) {
4231 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4232 };
4233 ArrayAttr extractOffsets = op.getOffsets();
4234 ArrayAttr extractStrides = op.getStrides();
4235 ArrayAttr extractSizes = op.getSizes();
4236 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4237 while (insertOp) {
4238 if (op.getSourceVectorType().getRank() !=
4239 insertOp.getSourceVectorType().getRank())
4240 return failure();
4241 ArrayAttr insertOffsets = insertOp.getOffsets();
4242 ArrayAttr insertStrides = insertOp.getStrides();
4243 // If the rank of extract is greater than the rank of insert, we are likely
4244 // extracting a partial chunk of the vector inserted.
4245 if (extractOffsets.size() > insertOffsets.size())
4246 return failure();
4247 bool partialOverlap = false;
4248 bool disjoint = false;
4249 SmallVector<int64_t, 4> offsetDiffs;
4250 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4251 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4252 return failure();
4253 int64_t start = getElement(insertOffsets, dim);
4254 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4255 int64_t offset = getElement(extractOffsets, dim);
4256 int64_t size = getElement(extractSizes, dim);
4257 // Check if the start of the extract offset is in the interval inserted.
4258 if (start <= offset && offset < end) {
4259 // If the extract interval overlaps but is not fully included we may
4260 // have a partial overlap that will prevent any folding.
4261 if (offset + size > end)
4262 partialOverlap = true;
4263 offsetDiffs.push_back(offset - start);
4264 continue;
4265 }
4266 disjoint = true;
4267 break;
4268 }
4269 // The extracted chunk is a subset of the inserted chunk.
4270 if (!disjoint && !partialOverlap) {
4271 op.setOperand(insertOp.getValueToStore());
4272 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4273 OpBuilder b(op.getContext());
4274 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4275 return success();
4276 }
4277 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4278 // in the insert chain.
4279 if (disjoint)
4280 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4281 else {
4282 // The extracted vector partially overlaps the inserted vector; we cannot
4283 // fold.
4284 return failure();
4285 }
4286 }
4287 return failure();
4288}
4289
4290// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4291static OpFoldResult
4292foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
4293 Attribute foldInput) {
4294
4295 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4296 if (!dense)
4297 return {};
4298
4299 // TODO: Handle non-unit strides when they become available.
4300 if (op.hasNonUnitStrides())
4301 return {};
4302
4303 VectorType sourceVecTy = op.getSourceVectorType();
4304 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4305 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4306
4307 VectorType sliceVecTy = op.getType();
4308 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4309 int64_t rank = sliceVecTy.getRank();
4310
4311 // Expand offsets and sizes to match the vector rank.
4312 SmallVector<int64_t, 4> offsets(rank, 0);
4313 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4314
4315 SmallVector<int64_t, 4> sizes(sourceShape);
4316 copy(getI64SubArray(op.getSizes()), sizes.begin());
4317
4318 // Calculate the slice elements by enumerating all slice positions and
4319 // linearizing them. The enumeration order is lexicographic which yields a
4320 // sequence of monotonically increasing linearized position indices.
4321 const auto denseValuesBegin = dense.value_begin<Attribute>();
4322 SmallVector<Attribute> sliceValues;
4323 sliceValues.reserve(sliceVecTy.getNumElements());
4324 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4325 do {
4326 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4327 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4328 "Invalid index");
4329 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4330 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4331
4332 assert(static_cast<int64_t>(sliceValues.size()) ==
4333 sliceVecTy.getNumElements() &&
4334 "Invalid number of slice elements");
4335 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4336}
4337
4338OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4339 if (getSourceVectorType() == getResult().getType())
4340 return getSource();
4341 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4342 return getResult();
4343
4344 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4345 if (auto splat =
4346 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4347 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4348
4349 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4350 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4351}
4352
4353void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4354 populateFromInt64AttrArray(getOffsets(), results);
4355}
4356
4357namespace {
4358
4359// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4360// CreateMaskOp.
4361//
4362// Example:
4363//
4364// %mask = vector.create_mask %ub : vector<16xi1>
4365// %slice = vector.extract_strided_slice [%offset] [8] [1]
4366//
4367// to
4368//
4369// %new_ub = arith.subi %ub, %offset
4370// %mask = vector.create_mask %new_ub : vector<8xi1>
4371class StridedSliceCreateMaskFolder final
4372 : public OpRewritePattern<ExtractStridedSliceOp> {
4373 using Base::Base;
4374
4375public:
4376 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4377 PatternRewriter &rewriter) const override {
4378 Location loc = extractStridedSliceOp.getLoc();
4379 // Return if 'extractStridedSliceOp' operand is not defined by a
4380 // CreateMaskOp.
4381 auto createMaskOp =
4382 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4383 if (!createMaskOp)
4384 return failure();
4385 // Return if 'extractStridedSliceOp' has non-unit strides.
4386 if (extractStridedSliceOp.hasNonUnitStrides())
4387 return failure();
4388 // Gather constant mask dimension sizes.
4389 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4390 // Gather strided slice offsets and sizes.
4391 SmallVector<int64_t> sliceOffsets;
4392 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4393 sliceOffsets);
4394 SmallVector<int64_t> sliceSizes;
4395 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4396
4397 // Compute slice of vector mask region.
4398 SmallVector<Value> sliceMaskDimSizes;
4399 sliceMaskDimSizes.reserve(maskDimSizes.size());
4400 // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4401 // only iterate on the leading dim sizes. The tail accounts for the
4402 // remaining dim sizes.
4403 for (auto [maskDimSize, sliceOffset, sliceSize] :
4404 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4405 // No need to clamp on min/max values, because create_mask has clamping
4406 // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4407 // greater than the vector dim size.
4408 IntegerAttr offsetAttr =
4409 rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4410 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4411 Value sliceMaskDimSize =
4412 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4413 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4414 }
4415 // Add unchanged dimensions.
4416 llvm::append_range(
4417 sliceMaskDimSizes,
4418 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4419 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4420 // region.
4421 rewriter.replaceOpWithNewOp<CreateMaskOp>(
4422 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4423 sliceMaskDimSizes);
4424 return success();
4425 }
4426};
4427
4428// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4429// ConstantMaskOp.
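//
// Illustrative example (not part of the upstream comment; shapes assumed):
//
//   %mask = vector.constant_mask [10] : vector<16xi1>
//   %slice = vector.extract_strided_slice %mask
//       {offsets = [4], sizes = [8], strides = [1]}
//       : vector<16xi1> to vector<8xi1>
//
// folds to
//
//   %slice = vector.constant_mask [6] : vector<8xi1>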
4430class StridedSliceConstantMaskFolder final
4431 : public OpRewritePattern<ExtractStridedSliceOp> {
4432public:
4433 using Base::Base;
4434
4435 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4436 PatternRewriter &rewriter) const override {
4437 // Return if 'extractStridedSliceOp' operand is not defined by a
4438 // ConstantMaskOp.
4439 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4440 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4441 if (!constantMaskOp)
4442 return failure();
4443 // Return if 'extractStridedSliceOp' has non-unit strides.
4444 if (extractStridedSliceOp.hasNonUnitStrides())
4445 return failure();
4446 // Gather constant mask dimension sizes.
4447 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4448 // Gather strided slice offsets and sizes.
4449 SmallVector<int64_t> sliceOffsets;
4450 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4451 sliceOffsets);
4452 SmallVector<int64_t> sliceSizes;
4453 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4454
4455 // Compute slice of vector mask region.
4456 SmallVector<int64_t> sliceMaskDimSizes;
4457 sliceMaskDimSizes.reserve(maskDimSizes.size());
4458 for (auto [maskDimSize, sliceOffset, sliceSize] :
4459 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4460 int64_t sliceMaskDimSize = std::max(
4461 static_cast<int64_t>(0),
4462 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4463 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4464 }
4465 // Add unchanged dimensions.
4466 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4467 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4468 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4469 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4470 // region is a conjunction of mask dim intervals).
4471 if (llvm::is_contained(sliceMaskDimSizes, 0))
4472 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4473
4474 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4475 // region.
4476 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4477 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4478 sliceMaskDimSizes);
4479 return success();
4480 }
4481};
4482
4483// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4484// BroadcastOp(ExtractStrideSliceOp).
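//
// Illustrative example (not part of the upstream comment; shapes assumed):
//
//   %b = vector.broadcast %src : vector<1x8xf32> to vector<4x8xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]}
//       : vector<4x8xf32> to vector<2x4xf32>
//
// becomes (the broadcast dim needs no slicing, only the non-broadcast dim does)
//
//   %s0 = vector.extract_strided_slice %src
//       {offsets = [0, 2], sizes = [1, 4], strides = [1, 1]}
//       : vector<1x8xf32> to vector<1x4xf32>
//   %s = vector.broadcast %s0 : vector<1x4xf32> to vector<2x4xf32>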
4485class StridedSliceBroadcast final
4486 : public OpRewritePattern<ExtractStridedSliceOp> {
4487public:
4488 using Base::Base;
4489
4490 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4491 PatternRewriter &rewriter) const override {
4492 auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4493 if (!broadcast)
4494 return failure();
4495 auto srcVecType =
4496 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4497 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4498 auto dstVecType = llvm::cast<VectorType>(op.getType());
4499 unsigned dstRank = dstVecType.getRank();
4500 unsigned rankDiff = dstRank - srcRank;
4501 // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4502 // (n -> m with n > m). If they are originally both broadcasted *and*
4503 // sliced, this can be simplified to just broadcasting.
4504 bool needsSlice = false;
4505 for (unsigned i = 0; i < srcRank; i++) {
4506 if (srcVecType.getDimSize(i) != 1 &&
4507 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4508 needsSlice = true;
4509 break;
4510 }
4511 }
4512 Value source = broadcast.getSource();
4513 if (needsSlice) {
4514 SmallVector<int64_t> offsets =
4515 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4516 SmallVector<int64_t> sizes =
4517 getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4518 for (unsigned i = 0; i < srcRank; i++) {
4519 if (srcVecType.getDimSize(i) == 1) {
4520 // In case this dimension was broadcasted *and* sliced, the offset
4521 // and size need to be updated now that there is no broadcast before
4522 // the slice.
4523 offsets[i] = 0;
4524 sizes[i] = 1;
4525 }
4526 }
4527 source = ExtractStridedSliceOp::create(
4528 rewriter, op->getLoc(), source, offsets, sizes,
4529 getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4530 }
4531 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4532 return success();
4533 }
4534};
4535
4536/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
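///
/// Illustrative example (not part of the upstream comment; types assumed):
///
///   %splat = vector.broadcast %scalar : f32 to vector<8x4xf32>
///   %slice = vector.extract_strided_slice %splat
///       {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]}
///       : vector<8x4xf32> to vector<2x4xf32>
///
/// becomes
///
///   %slice = vector.broadcast %scalar : f32 to vector<2x4xf32>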
4537class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4538public:
4539 using Base::Base;
4540
4541 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4542 PatternRewriter &rewriter) const override {
4543
4544 Value splat = getScalarSplatSource(op.getSource());
4545 if (!splat)
4546 return failure();
4547 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4548 return success();
4549 }
4550};
4551
4552/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4553/// slice is contiguous, into extract and shape_cast.
4554///
4555/// Example:
4556/// Before:
4557/// %1 = vector.extract_strided_slice %arg0 {
4558/// offsets = [0, 0, 0, 0, 0],
4559/// sizes = [1, 1, 1, 1, 8],
4560/// strides = [1, 1, 1, 1, 1]
4561/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4562/// After:
4563/// %0 = vector.extract %arg0[0, 0, 0, 0]
4564/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4565/// %1 = vector.shape_cast %0
4566/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4567///
4568class ContiguousExtractStridedSliceToExtract final
4569 : public OpRewritePattern<ExtractStridedSliceOp> {
4570public:
4571 using Base::Base;
4572
4573 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4574 PatternRewriter &rewriter) const override {
4575 if (op.hasNonUnitStrides())
4576 return failure();
4577 Value source = op.getOperand();
4578 auto sourceType = cast<VectorType>(source.getType());
4579 if (sourceType.isScalable() || sourceType.getRank() == 0)
4580 return failure();
4581
4582 // Compute the number of offsets to pass to ExtractOp::build. That is the
4583 // difference between the source rank and the desired slice rank. We walk
4584 // the dimensions from innermost out, and stop when the next slice dimension
4585 // is not full-size.
4586 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4587 int numOffsets;
4588 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4589 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4590 break;
4591 }
4592
4593 // If the created extract op would have no offsets, then this whole
4594 // extract_strided_slice is the identity and should have been handled by
4595 // other canonicalizations.
4596 if (numOffsets == 0)
4597 return failure();
4598
4599 // If not even the inner-most dimension is full-size, this op can't be
4600 // rewritten as an ExtractOp.
4601 if (numOffsets == sourceType.getRank() &&
4602 static_cast<int>(sizes.size()) == sourceType.getRank())
4603 return failure();
4604
4605 // The outer dimensions must have unit size.
4606 for (int i = 0; i < numOffsets; ++i) {
4607 if (sizes[i] != 1)
4608 return failure();
4609 }
4610
4611 // Avoid generating slices that have leading unit dimensions. The shape_cast
4612 // op that we create below would otherwise have to go through the inefficient
4613 // generic fallback patterns (ShapeCastOpRewritePattern).
4614 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4615 sizes[numOffsets] == 1) {
4616 ++numOffsets;
4617 }
4618
4619 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4620 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4621 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4622 extractOffsets);
4623 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4624 return success();
4625 }
4626};
4627
4628} // namespace
4629
4630void ExtractStridedSliceOp::getCanonicalizationPatterns(
4631 RewritePatternSet &results, MLIRContext *context) {
4632 // Patterns to rewrite ExtractStridedSliceOp(CreateMaskOp/ConstantMaskOp) to
4633 // CreateMaskOp/ConstantMaskOp, and to simplify slices of broadcasts and splats.
4634 results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
4635 StridedSliceBroadcast, StridedSliceSplat,
4636 ContiguousExtractStridedSliceToExtract>(context);
4637}
4638
4639//===----------------------------------------------------------------------===//
4640// TransferReadOp
4641//===----------------------------------------------------------------------===//
4642
4643/// 1. Builder that sets padding to 'ub.poison' (when not provided) and an empty mask (variant with attrs).
4644void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4645 VectorType vectorType, Value source,
4646 ValueRange indices, std::optional<Value> padding,
4647 AffineMapAttr permutationMapAttr,
4648 /*optional*/ ArrayAttr inBoundsAttr) {
4649
4650 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4651 if (!padding)
4652 padding = ub::PoisonOp::create(builder, result.location, elemType);
4653 build(builder, result, vectorType, source, indices, permutationMapAttr,
4654 *padding, /*mask=*/Value(), inBoundsAttr);
4655}
4656
4657/// 2. Builder that sets padding to 'ub.poison' (when not provided) and an empty mask (variant without attrs).
4658void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4659 VectorType vectorType, Value source,
4660 ValueRange indices, std::optional<Value> padding,
4661 AffineMap permutationMap,
4662 std::optional<ArrayRef<bool>> inBounds) {
4663 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4664 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4665 ? builder.getBoolArrayAttr(inBounds.value())
4666 : builder.getBoolArrayAttr(
4667 SmallVector<bool>(vectorType.getRank(), false));
4668 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4669 if (!padding)
4670 padding = ub::PoisonOp::create(builder, result.location, elemType);
4671 build(builder, result, vectorType, source, indices, *padding,
4672 permutationMapAttr, inBoundsAttr);
4673}
4674
4675/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4676void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4677 VectorType vectorType, Value source,
4678 ValueRange indices, std::optional<Value> padding,
4679 std::optional<ArrayRef<bool>> inBounds) {
4680 AffineMap permutationMap = getTransferMinorIdentityMap(
4681 llvm::cast<ShapedType>(source.getType()), vectorType);
4682 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4683 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4684 ? builder.getBoolArrayAttr(inBounds.value())
4685 : builder.getBoolArrayAttr(
4686 SmallVector<bool>(vectorType.getRank(), false));
4687 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4688 if (!padding)
4689 padding = ub::PoisonOp::create(builder, result.location, elemType);
4690 build(builder, result, vectorType, source, indices, permutationMapAttr,
4691 *padding,
4692 /*mask=*/Value(), inBoundsAttr);
4693}
4694
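// For reference, a sketch of what the verifier below accepts (illustrative,
// not from the upstream source). With a rank-3 source, valid projected
// permutation maps include:
//   affine_map<(d0, d1, d2) -> (d2, d1)>   // projected permutation
//   affine_map<(d0, d1, d2) -> (d1, 0)>    // zero (broadcast) results are ok
// whereas the following are rejected:
//   affine_map<(d0, d1, d2) -> (d0 + d1)>  // not a dim or the zero constant
//   affine_map<(d0, d1, d2) -> (d1, d1)>   // a dim used more than once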
4695template <typename EmitFun>
4696static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4697 EmitFun emitOpError) {
4698 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4699 for (auto expr : permutationMap.getResults()) {
4700 auto dim = dyn_cast<AffineDimExpr>(expr);
4701 auto zero = dyn_cast<AffineConstantExpr>(expr);
4702 if (zero) {
4703 if (zero.getValue() != 0) {
4704 return emitOpError(
4705 "requires a projected permutation_map (at most one dim or the zero "
4706 "constant can appear in each result)");
4707 }
4708 continue;
4709 }
4710 if (!dim) {
4711 return emitOpError("requires a projected permutation_map (at most one "
4712 "dim or the zero constant can appear in each result)");
4713 }
4714 if (seen[dim.getPosition()]) {
4715 return emitOpError(
4716 "requires a permutation_map that is a permutation (found one dim "
4717 "used more than once)");
4718 }
4719 seen[dim.getPosition()] = true;
4720 }
4721 return success();
4722}
4723
4724static LogicalResult
4725verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4726 VectorType vectorType, VectorType maskType,
4727 VectorType inferredMaskType, AffineMap permutationMap,
4728 ArrayAttr inBounds) {
4729 if (op->hasAttr("masked")) {
4730 return op->emitOpError("masked attribute has been removed. "
4731 "Use in_bounds instead.");
4732 }
4733
4734 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4735 return op->emitOpError(
4736 "requires source to be a memref or ranked tensor type");
4737
4738 auto elementType = shapedType.getElementType();
4739 DataLayout dataLayout = DataLayout::closest(op);
4740 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4741 // Memref or tensor has vector element type.
4742 unsigned sourceVecSize =
4743 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4744 vectorElementType.getShape().back();
4745 unsigned resultVecSize =
4746 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4747 vectorType.getShape().back();
4748 if (resultVecSize % sourceVecSize != 0)
4749 return op->emitOpError(
4750 "requires the bitwidth of the minor 1-D vector to be an integral "
4751 "multiple of the bitwidth of the minor 1-D vector of the source");
4752
4753 unsigned sourceVecEltRank = vectorElementType.getRank();
4754 unsigned resultVecRank = vectorType.getRank();
4755 if (sourceVecEltRank > resultVecRank)
4756 return op->emitOpError(
4757 "requires source vector element and vector result ranks to match.");
4758 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4759 // Check that permutation map results match 'rankOffset' of vector type.
4760 if (permutationMap.getNumResults() != rankOffset)
4761 return op->emitOpError("requires a permutation_map with result dims of "
4762 "the same rank as the vector type");
4763
4764 if (maskType)
4765 return op->emitOpError("does not support masks with vector element type");
4766 } else {
4767 // Memref or tensor has scalar element type.
4768 unsigned minorSize =
4769 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4770 unsigned resultVecSize =
4771 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4772 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4773 return op->emitOpError(
4774 "requires the bitwidth of the minor 1-D vector to be an integral "
4775 "multiple of the bitwidth of the source element type");
4776
4777 // Check that permutation map results match rank of vector type.
4778 if (permutationMap.getNumResults() != vectorType.getRank())
4779 return op->emitOpError("requires a permutation_map with result dims of "
4780 "the same rank as the vector type");
4781 }
4782
4783 if (permutationMap.getNumSymbols() != 0)
4784 return op->emitOpError("requires permutation_map without symbols");
4785
4786 if (permutationMap.getNumInputs() != shapedType.getRank())
4787 return op->emitOpError("requires a permutation_map with input dims of the "
4788 "same rank as the source type");
4789
4790 if (maskType && maskType != inferredMaskType)
4791 return op->emitOpError("inferred mask type (")
4792 << inferredMaskType << ") and mask operand type (" << maskType
4793 << ") don't match";
4794
4795 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4796 return op->emitOpError("expects the in_bounds attr of same rank "
4797 "as permutation_map results: ")
4798 << AffineMapAttr::get(permutationMap)
4799 << " vs inBounds of size: " << inBounds.size();
4800
4801 return success();
4802}
4803
4804static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4805 SmallVector<StringRef, 3> elidedAttrs;
4806 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4807 if (op.getPermutationMap().isMinorIdentity())
4808 elidedAttrs.push_back(op.getPermutationMapAttrName());
4809 // Elide in_bounds attribute if all dims are out-of-bounds.
4810 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4811 elidedAttrs.push_back(op.getInBoundsAttrName());
4812 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4813}
4814
4815void TransferReadOp::print(OpAsmPrinter &p) {
4816 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4817 if (getMask())
4818 p << ", " << getMask();
4819 printTransferAttrs(p, *this);
4820 p << " : " << getShapedType() << ", " << getVectorType();
4821}
4822
4823VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4824 AffineMap permMap) {
4825 auto i1Type = IntegerType::get(permMap.getContext(), 1);
4826 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4827 assert(invPermMap && "Inversed permutation map couldn't be computed");
4828 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4829
4830 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4831 // 0-D mask into a single-element 1-D mask.
4832 if (maskShape.empty())
4833 maskShape.push_back(1);
4834
4835 SmallVector<bool> scalableDims =
4836 applyPermutationMap(invPermMap, vecType.getScalableDims());
4837
4838 return VectorType::get(maskShape, i1Type, scalableDims);
4839}
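// Illustrative example (not from the upstream source): a transfer_read with
//   permutation_map = affine_map<(d0, d1) -> (d1, d0)>
// producing vector<4x8xf32> expects a mask of type vector<8x4xi1>, i.e. the
// mask shape follows the (compressed, inverse-permuted) source dim order
// rather than the vector dim order.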
4840
4841ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4842 auto &builder = parser.getBuilder();
4843 SMLoc typesLoc;
4844 OpAsmParser::UnresolvedOperand sourceInfo;
4845 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4846 OpAsmParser::UnresolvedOperand paddingInfo;
4847 SmallVector<Type, 2> types;
4848 OpAsmParser::UnresolvedOperand maskInfo;
4849 // Parsing with support for paddingValue.
4850 if (parser.parseOperand(sourceInfo) ||
4851 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4852 parser.parseComma() || parser.parseOperand(paddingInfo))
4853 return failure();
4854 ParseResult hasMask = parser.parseOptionalComma();
4855 if (hasMask.succeeded()) {
4856 if (parser.parseOperand(maskInfo))
4857 return failure();
4858 }
4859 if (parser.parseOptionalAttrDict(result.attributes) ||
4860 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4861 return failure();
4862 if (types.size() != 2)
4863 return parser.emitError(typesLoc, "requires two types");
4864 auto indexType = builder.getIndexType();
4865 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4866 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4867 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4868 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4869 if (!vectorType)
4870 return parser.emitError(typesLoc, "requires vector type");
4871 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4872 Attribute permMapAttr = result.attributes.get(permMapAttrName);
4873 AffineMap permMap;
4874 if (!permMapAttr) {
4875 if (shapedType.getRank() <
4876 getEffectiveVectorRankForXferOp(shapedType, vectorType))
4877 return parser.emitError(typesLoc,
4878 "expected a custom permutation_map when "
4879 "rank(source) != rank(destination)");
4880 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4881 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4882 } else {
4883 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4884 }
4885 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4886 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4887 if (!inBoundsAttr) {
4888 result.addAttribute(inBoundsAttrName,
4889 builder.getBoolArrayAttr(
4890 SmallVector<bool>(permMap.getNumResults(), false)));
4891 }
4892 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4893 parser.resolveOperands(indexInfo, indexType, result.operands) ||
4894 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4895 result.operands))
4896 return failure();
4897 if (hasMask.succeeded()) {
4898 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4899 return parser.emitError(
4900 maskInfo.location, "does not support masks with vector element type");
4901 if (vectorType.getRank() != permMap.getNumResults()) {
4902 return parser.emitError(typesLoc,
4903 "expected the same rank for the vector and the "
4904 "results of the permutation map");
4905 }
4906 // Instead of adding the mask type as an op type, compute it based on the
4907 // vector type and the permutation map (to keep the type signature small).
4908 auto maskType = inferTransferOpMaskType(vectorType, permMap);
4909 if (parser.resolveOperand(maskInfo, maskType, result.operands))
4910 return failure();
4911 }
4912 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4913 builder.getDenseI32ArrayAttr(
4914 {1, static_cast<int32_t>(indexInfo.size()), 1,
4915 static_cast<int32_t>(hasMask.succeeded())}));
4916 return parser.addTypeToList(vectorType, result.types);
4917}
4918
4919LogicalResult TransferReadOp::verify() {
4920 // Consistency of elemental types in source and vector.
4921 ShapedType shapedType = getShapedType();
4922 VectorType vectorType = getVectorType();
4923 VectorType maskType = getMaskType();
4924 auto paddingType = getPadding().getType();
4925 auto permutationMap = getPermutationMap();
4926 VectorType inferredMaskType =
4927 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4928 : VectorType();
4929 auto sourceElementType = shapedType.getElementType();
4930
4931 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4932 return emitOpError("requires ") << shapedType.getRank() << " indices";
4933
4934 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4935 shapedType, vectorType, maskType,
4936 inferredMaskType, permutationMap, getInBounds())))
4937 return failure();
4938
4939 if (auto sourceVectorElementType =
4940 llvm::dyn_cast<VectorType>(sourceElementType)) {
4941 // Source has vector element type.
4942 // Check that 'sourceVectorElementType' and 'paddingType' types match.
4943 if (sourceVectorElementType != paddingType)
4944 return emitOpError(
4945 "requires source element type and padding type to match.");
4946
4947 } else {
4948 // Check that 'paddingType' is valid to store in a vector type.
4949 if (!VectorType::isValidElementType(paddingType))
4950 return emitOpError("requires valid padding vector elemental type");
4951
4952 // Check that padding type and vector element types match.
4953 if (paddingType != sourceElementType)
4954 return emitOpError(
4955 "requires formal padding and source of the same elemental type");
4956 }
4957
4958 return verifyPermutationMap(permutationMap,
4959 [&](Twine t) { return emitOpError(t); });
4960}
4961
4962// MaskableOpInterface methods.
4963
4964/// Returns the mask type expected by this operation. Mostly used for
4965/// verification purposes. It requires the operation to be vectorized.
4966Type TransferReadOp::getExpectedMaskType() {
4967 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4968}
4969
4970//===----------------------------------------------------------------------===//
4971// TransferReadOp: VectorTransferOpInterface methods.
4972//===----------------------------------------------------------------------===//
4973VectorType TransferReadOp::getVectorType() {
4974 return cast<VectorType>(getVector().getType());
4975}
4976
4977template <typename TransferOp>
4978static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4979 // TODO: support more aggressive createOrFold on:
4980 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4981 if (op.getShapedType().isDynamicDim(indicesIdx))
4982 return false;
4983 Value index = op.getIndices()[indicesIdx];
4984 std::optional<int64_t> cstOp = getConstantIntValue(index);
4985 if (!cstOp.has_value())
4986 return false;
4987
4988 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4989 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4990
4991 return cstOp.value() + vectorSize <= sourceSize;
4992}
4993
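// Illustrative example (not from the upstream source): for
//   %r = vector.transfer_read %m[%c2, %c0], %pad
//       : memref<8x16xf32>, vector<4x16xf32>
// both dims are statically in-bounds (2 + 4 <= 8 and 0 + 16 <= 16), so the
// fold below rewrites the op with in_bounds = [true, true].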
4994template <typename TransferOp>
4995static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4996 // TODO: support 0-d corner case.
4997 // TODO: Be less conservative.
4998 if (op.getTransferRank() == 0)
4999 return failure();
5000 AffineMap permutationMap = op.getPermutationMap();
5001 bool changed = false;
5002 SmallVector<bool, 4> newInBounds;
5003 newInBounds.reserve(op.getTransferRank());
5004 // Indices of non-broadcast dims - used when analysing broadcast dims.
5005 SmallVector<unsigned> nonBcastDims;
5006
5007 // 1. Process non-broadcast dims
5008 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5009 // 1.1. Already marked as in-bounds, nothing to see here.
5010 if (op.isDimInBounds(i)) {
5011 newInBounds.push_back(true);
5012 continue;
5013 }
5014 // 1.2. Currently out-of-bounds, check whether we can statically determine
5015 // it is inBounds.
5016 bool inBounds = false;
5017 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5018 if (dimExpr) {
5019 inBounds = isInBounds(op, /*resultIdx=*/i,
5020 /*indicesIdx=*/dimExpr.getPosition());
5021 nonBcastDims.push_back(i);
5022 }
5023
5024 newInBounds.push_back(inBounds);
5025 // We commit the pattern if it is "more inbounds".
5026 changed |= inBounds;
5027 }
5028
5029 // 2. Handle broadcast dims
5030 // If all non-broadcast dims are "in bounds", then all bcast dims should be
5031 // "in bounds" as well.
5032 bool allNonBcastDimsInBounds = llvm::all_of(
5033 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5034 if (allNonBcastDimsInBounds) {
5035 for (size_t idx : permutationMap.getBroadcastDims()) {
5036 changed |= !newInBounds[idx];
5037 newInBounds[idx] = true;
5038 }
5039 }
5040
5041 if (!changed)
5042 return failure();
5043 // OpBuilder is only used as a helper to build an I64ArrayAttr.
5044 OpBuilder b(op.getContext());
5045 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5046 return success();
5047}
5048
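// Drops the mask operand when it is statically known to be all-true, e.g.
// (illustrative, not from the upstream source) a transfer masked by
// vector.constant_mask [4] : vector<4xi1> covering the whole vector.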
5049template <typename TransferOp>
5050static LogicalResult foldTransferFullMask(TransferOp op) {
5051 auto mask = op.getMask();
5052 if (!mask)
5053 return failure();
5054
5055 if (getMaskFormat(mask) != MaskFormat::AllTrue)
5056 return failure();
5057
5058 op.getMaskMutable().clear();
5059 return success();
5060}
5061
5062/// ```
5063/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5064/// : vector<1x4xf32>, tensor<4x4xf32>
5065/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5066/// : tensor<4x4xf32>, vector<1x4xf32>
5067/// ```
5068/// -> Folds into
5069/// ```
5070/// %v0
5071/// ```
5072static Value foldRAW(TransferReadOp readOp) {
5073 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5074 return {};
5075 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5076 while (defWrite) {
5077 if (checkSameValueRAW(defWrite, readOp))
5078 return defWrite.getVector();
5079 if (!isDisjointTransferIndices(
5080 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5081 cast<VectorTransferOpInterface>(readOp.getOperation())))
5082 break;
5083 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5084 }
5085 return {};
5086}
5087
5088OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5089 if (Value vec = foldRAW(*this))
5090 return vec;
5091 /// transfer_read(memrefcast) -> transfer_read
5092 if (succeeded(foldTransferInBoundsAttribute(*this)))
5093 return getResult();
5094 if (succeeded(foldTransferFullMask(*this)))
5095 return getResult();
5096 if (succeeded(memref::foldMemRefCast(*this)))
5097 return getResult();
5098 if (succeeded(tensor::foldTensorCast(*this)))
5099 return getResult();
5100 return OpFoldResult();
5101}
5102
5103std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5104 return llvm::to_vector<4>(getVectorType().getShape());
5105}
5106
5107void TransferReadOp::getEffects(
5108 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5109 &effects) {
5110 if (llvm::isa<MemRefType>(getShapedType()))
5111 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5112 SideEffects::DefaultResource::get());
5113}
5114
5115Speculation::Speculatability TransferReadOp::getSpeculatability() {
5116 if (hasPureTensorSemantics())
5117 return Speculation::Speculatable;
5118 return Speculation::NotSpeculatable;
5119}
5120
5121/// Given a projected permutation map, return its inverse, mapping the unused
5122/// input dims to the constant 0 in the result.
5123static AffineMap inverseWithUnusedDims(AffineMap map) {
5124 assert(map.isProjectedPermutation() &&
5125 "expected a projected permutation map");
5126 SmallVector<AffineExpr> results(map.getNumInputs(),
5127 getAffineConstantExpr(0, map.getContext()));
5128 for (auto [idx, result] : llvm::enumerate(map.getResults())) {
5129 // We should only have dim exprs because this is a projected permutation.
5130 int64_t pos = cast<AffineDimExpr>(result).getPosition();
5131 results[pos] = getAffineDimExpr(idx, map.getContext());
5132 }
5133 return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
5134 results, map.getContext());
5135}
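// Illustrative example (not from the upstream source):
//   (d0, d1, d2) -> (d2, d0)   inverts to   (d0, d1) -> (d1, 0, d0)
// i.e. the unused input dim d1 becomes the constant 0 in the result.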
5136
5137namespace {
5138/// Store to load forwarding for transfer operations with permutation maps.
5139/// Even if the permutation maps are different we can still propagate the store
5140/// into the load if the size of the dimensions read and written match. Then we
5141/// can replace the transfer_read + transfer_write by vector.broadcast and
5142/// vector.transpose.
5143/// Example:
5144/// ```
5145/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5146/// {in_bounds = [true, true],
5147/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5148/// vector<4x1xf32>, tensor<4x4x4xf32>
5149/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5150/// {in_bounds = [true, true, true, true],
5151/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5152/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5153/// ```
5154/// To:
5155/// ```
5156/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
5157/// %r = vector.transpose %0, [3, 0, 2, 1] :
5158/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5159/// ```
5160struct TransferReadAfterWriteToBroadcast
5161 : public OpRewritePattern<TransferReadOp> {
5162 using Base::Base;
5163
5164 LogicalResult matchAndRewrite(TransferReadOp readOp,
5165 PatternRewriter &rewriter) const override {
5166 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5167 if (!defWrite)
5168 return failure();
5169 // Bail if we need an alias analysis.
5170 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5171 return failure();
5172 // Bail in the masked case (too complex atm and needed to properly account
5173 // for padding).
5174 if (readOp.getMask() || defWrite.getMask())
5175 return failure();
5176 // If indices are not the same a shift may be required, bail.
5177 if (readOp.getIndices() != defWrite.getIndices())
5178 return failure();
5179 // Bail if we need a bounds analysis.
5180 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5181 return failure();
5182 // TODO: If the written transfer chunk is a superset of the read transfer
5183 // chunk we could do an extract_strided_slice.
5184 if (readOp.getTransferChunkAccessed() !=
5185 defWrite.getTransferChunkAccessed())
5186 return failure();
5187 // WriteMap: tensor -> w_vec
5188 // ReadMap: tensor -> r_vec
5189 //
5190 // inv(WriteMap): w_vec -> tensor
5191 // inv(WriteMap) o ReadMap: w_vec -> r_vec
5192 AffineMap readMap = readOp.getPermutationMap();
5193 AffineMap writeMap = defWrite.getPermutationMap();
5194 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5195 AffineMap composedMap = readMap.compose(invWriteMap);
5196 // If there are any unused dims in the composedMap, we have to drop some
5197 // unit dims from the written vector before we can do transpose(broadcast).
5198 // TODO: Support this case.
5199 if (getUnusedDimsBitVector(composedMap).any())
5200 return failure();
5201 // readVec = transpose(broadcast(writeVec))
5202 //
5203 // Build a transpose permutation for the above transpose operation.
5204 //
5205 // Treat the composed map as having extra leading dimensions which are
5206 // the broadcasted dimensions, and treat the zeros as these new broadcasted
5207 // dimensions.
5208 SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
5209 int64_t numBroadcastedDims = broadcastedDims.size();
5210 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5211 invPerm.resize(composedMap.getNumResults());
5212 for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
5213 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
5214 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5215 invPerm[effectiveDim] = idx;
5216 }
5217 }
5218 // Applying the inverse permutation on the readVecTy will give us the
5219 // broadcast result type.
5220 VectorType readVecTy = readOp.getVectorType();
5221 SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
5222 auto broadcastedVecTy =
5223 VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
5224 readVecTy.getElementType(),
5225 applyPermutation(readVecTy.getScalableDims(), invPerm));
5226 // Build the transpose(broadcast) transformation.
5227 Value vec = defWrite.getVector();
5228 Location loc = readOp.getLoc();
5229 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5230 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
5231 return success();
5232 }
5233};
5234} // namespace
5235
5236void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5237 MLIRContext *context) {
5238 results.add<TransferReadAfterWriteToBroadcast>(context);
5239}
5240
5241FailureOr<std::optional<SmallVector<Value>>>
5242TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5243 if (!hasPureBufferSemantics())
5244 return failure();
5246 getResult());
5247}
5248
5249//===----------------------------------------------------------------------===//
5250// TransferWriteOp
5251//===----------------------------------------------------------------------===//
5252
5253/// 1. Builder with type inference.
5254void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5255 Value vector, Value dest, ValueRange indices,
5256 AffineMapAttr permutationMapAttr,
5257 /*optional*/ Value mask,
5258 /*optional*/ ArrayAttr inBoundsAttr) {
5259 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5260 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5261 mask, inBoundsAttr);
5262}
5263
5264/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5265void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5266 Value vector, Value dest, ValueRange indices,
5267 AffineMapAttr permutationMapAttr,
5268 /*optional*/ ArrayAttr inBoundsAttr) {
5269 build(builder, result, vector, dest, indices, permutationMapAttr,
5270 /*mask=*/Value(), inBoundsAttr);
5271}
5272
5273/// 3. Builder with type inference that sets an empty mask (variant without
5274/// attrs)
5275void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5276 Value vector, Value dest, ValueRange indices,
5277 AffineMap permutationMap,
5278 std::optional<ArrayRef<bool>> inBounds) {
5279 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5280 auto inBoundsAttr =
5281 (inBounds && !inBounds.value().empty())
5282 ? builder.getBoolArrayAttr(inBounds.value())
5283 : builder.getBoolArrayAttr(SmallVector<bool>(
5284 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5285 build(builder, result, vector, dest, indices, permutationMapAttr,
5286 /*mask=*/Value(), inBoundsAttr);
5287}
5288
5289/// 4. Builder with type inference that sets an empty mask and sets permutation
5290/// map to 'getMinorIdentityMap'.
5291void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5292 Value vector, Value dest, ValueRange indices,
5293 std::optional<ArrayRef<bool>> inBounds) {
5294 auto vectorType = llvm::cast<VectorType>(vector.getType());
5295 AffineMap permutationMap = getTransferMinorIdentityMap(
5296 llvm::cast<ShapedType>(dest.getType()), vectorType);
5297 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5298}
5299
5300ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5301 OperationState &result) {
5302 auto &builder = parser.getBuilder();
5303 SMLoc typesLoc;
5304 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5305 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5306 SmallVector<Type, 2> types;
5307 OpAsmParser::UnresolvedOperand maskInfo;
5308 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5309 parser.parseOperand(sourceInfo) ||
5310 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5311 return failure();
5312 ParseResult hasMask = parser.parseOptionalComma();
5313 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5314 return failure();
5315 if (parser.parseOptionalAttrDict(result.attributes) ||
5316 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5317 return failure();
5318 if (types.size() != 2)
5319 return parser.emitError(typesLoc, "requires two types");
5320 auto indexType = builder.getIndexType();
5321 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5322 if (!vectorType)
5323 return parser.emitError(typesLoc, "requires vector type");
5324 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5325 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5326 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5327 auto permMapAttrName =
5328 TransferWriteOp::getPermutationMapAttrName(result.name);
5329 auto permMapAttr = result.attributes.get(permMapAttrName);
5330 AffineMap permMap;
5331 if (!permMapAttr) {
5332 if (shapedType.getRank() <
5333 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5334 return parser.emitError(typesLoc,
5335 "expected a custom permutation_map when "
5336 "rank(source) != rank(destination)");
5337 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5338 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5339 } else {
5340 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5341 }
5342 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5343 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5344 if (!inBoundsAttr) {
5345 result.addAttribute(inBoundsAttrName,
5346 builder.getBoolArrayAttr(
5347 SmallVector<bool>(permMap.getNumResults(), false)));
5348 }
5349 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5350 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5351 parser.resolveOperands(indexInfo, indexType, result.operands))
5352 return failure();
5353 if (hasMask.succeeded()) {
5354 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5355 return parser.emitError(
5356 maskInfo.location, "does not support masks with vector element type");
5357 if (vectorType.getRank() != permMap.getNumResults()) {
5358 return parser.emitError(typesLoc,
5359 "expected the same rank for the vector and the "
5360 "results of the permutation map");
5361 }
5362 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5363 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5364 return failure();
5365 }
5366 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5367 builder.getDenseI32ArrayAttr(
5368 {1, 1, static_cast<int32_t>(indexInfo.size()),
5369 static_cast<int32_t>(hasMask.succeeded())}));
5370 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5371 parser.addTypeToList(shapedType, result.types));
5372}
5373
5374void TransferWriteOp::print(OpAsmPrinter &p) {
5375 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5376 if (getMask())
5377 p << ", " << getMask();
5378 printTransferAttrs(p, *this);
5379 p << " : " << getVectorType() << ", " << getShapedType();
5380}
5381
5382LogicalResult TransferWriteOp::verify() {
5383 // Consistency of elemental types in shape and vector.
5384 ShapedType shapedType = getShapedType();
5385 VectorType vectorType = getVectorType();
5386 VectorType maskType = getMaskType();
5387 auto permutationMap = getPermutationMap();
5388 VectorType inferredMaskType =
5389 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5390 : VectorType();
5391
5392 if (llvm::size(getIndices()) != shapedType.getRank())
5393 return emitOpError("requires ") << shapedType.getRank() << " indices";
5394
5395 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5396 // as the semantics is unclear. This can be revisited later if necessary.
5397 if (hasBroadcastDim())
5398 return emitOpError("should not have broadcast dimensions");
5399
5400 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5401 shapedType, vectorType, maskType,
5402 inferredMaskType, permutationMap, getInBounds())))
5403 return failure();
5404
5405 return verifyPermutationMap(permutationMap,
5406 [&](Twine t) { return emitOpError(t); });
5407}
5408
5409//===----------------------------------------------------------------------===//
5410// TransferWriteOp: MaskableOpInterface methods.
5411//===----------------------------------------------------------------------===//
5412
5413/// Returns the mask type expected by this operation. Mostly used for
5414/// verification purposes.
5415Type TransferWriteOp::getExpectedMaskType() {
5416 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5417}
5418
5419//===----------------------------------------------------------------------===//
5420// TransferWriteOp: VectorTransferOpInterface methods.
5421//===----------------------------------------------------------------------===//
5422Value TransferWriteOp::getVector() { return getOperand(0); }
5423VectorType TransferWriteOp::getVectorType() {
5424 return cast<VectorType>(getValueToStore().getType());
5425}
5426
5427//===----------------------------------------------------------------------===//
5428// TransferWriteOp: fold methods.
5429//===----------------------------------------------------------------------===//
5430/// Fold:
5431/// ```
5432/// %t1 = ...
5433/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5434/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5435/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5436/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5437/// ```
5438///
5439/// into:
5440///
5441/// ```
5442/// %t0
5443/// ```
5444///
5445/// The producer of t1 may or may not be DCE'd depending on whether it is a
5446/// block argument or has side effects.
5447static LogicalResult foldReadInitWrite(TransferWriteOp write,
5448 ArrayRef<Attribute>,
5449 SmallVectorImpl<OpFoldResult> &results) {
5450 // TODO: support 0-d corner case.
5451 if (write.getTransferRank() == 0)
5452 return failure();
5453 auto rankedTensorType =
5454 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5455 // If not operating on tensors, bail.
5456 if (!rankedTensorType)
5457 return failure();
5458 // If no read, bail.
5459 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5460 if (!read)
5461 return failure();
5462 // TODO: support 0-d corner case.
5463 if (read.getTransferRank() == 0)
5464 return failure();
5465 // For now, only accept minor identity maps. Future: accept compositions that are minor identities.
5466 if (!read.getPermutationMap().isMinorIdentity() ||
5467 !write.getPermutationMap().isMinorIdentity())
5468 return failure();
5469 // Bail on mismatching ranks.
5470 if (read.getTransferRank() != write.getTransferRank())
5471 return failure();
5472 // Bail on potential out-of-bounds accesses.
5473 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5474 return failure();
5475 // Tensor types must be the same.
5476 if (read.getBase().getType() != rankedTensorType)
5477 return failure();
5478 // Vector types must be the same.
5479 if (read.getVectorType() != write.getVectorType())
5480 return failure();
5481 // Vector and Tensor shapes must match.
5482 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5483 return failure();
5484 // If any index is nonzero.
5485 auto isNotConstantZero = [](Value v) {
5486 auto cstOp = getConstantIntValue(v);
5487 return !cstOp.has_value() || cstOp.value() != 0;
5488 };
5489 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5490 llvm::any_of(write.getIndices(), isNotConstantZero))
5491 return failure();
5492 // Success.
5493 results.push_back(read.getBase());
5494 return success();
5495}
5496
5497static bool checkSameValueWAR(vector::TransferReadOp read,
5498 vector::TransferWriteOp write) {
5499 return read.getBase() == write.getBase() &&
5500 read.getIndices() == write.getIndices() &&
5501 read.getPermutationMap() == write.getPermutationMap() &&
5502 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5503 !write.getMask();
5504}
5505/// Fold transfer_write write after read:
5506/// ```
5507/// %t0 = ...
5508/// %v = vector.transfer_read %t0[%c0...] :
5509/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5510/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5511/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5512/// ```
5513///
5514/// into:
5515///
5516/// ```
5517/// %t0
5518/// ```
5519static LogicalResult foldWAR(TransferWriteOp write,
5520 SmallVectorImpl<OpFoldResult> &results) {
5521 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5522 return failure();
5523 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5524 if (!read)
5525 return failure();
5526
5527 if (!checkSameValueWAR(read, write))
5528 return failure();
5529 results.push_back(read.getBase());
5530 return success();
5531}
5532
5533LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5534 SmallVectorImpl<OpFoldResult> &results) {
5535 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5536 return success();
5537 if (succeeded(foldWAR(*this, results)))
5538 return success();
5539 if (succeeded(foldTransferInBoundsAttribute(*this)))
5540 return success();
5541 if (succeeded(foldTransferFullMask(*this)))
5542 return success();
5543 return memref::foldMemRefCast(*this);
5544}
5545
5546//===----------------------------------------------------------------------===//
5547// TransferWriteOp: other methods.
5548//===----------------------------------------------------------------------===//
5549std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5550 return llvm::to_vector<4>(getVectorType().getShape());
5551}
5552
5553void TransferWriteOp::getEffects(
5554 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5555 &effects) {
5556 if (llvm::isa<MemRefType>(getShapedType()))
5557 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5558 SideEffects::DefaultResource::get());
5559}
5560
5561Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5562 if (hasPureTensorSemantics())
5563 return Speculation::Speculatable;
5564 return Speculation::NotSpeculatable;
5565}
5566
5567namespace {
5568/// Remove a dead transfer write from the SSA chain so that it can be
5569/// eliminated by DCE:
5570/// ```
5571/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5572/// : vector<1x4xf32>, tensor<4x4xf32>
5573/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5574/// : vector<1x4xf32>, tensor<4x4xf32>
5575/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5576/// : vector<1x4xf32>, tensor<4x4xf32>
5577/// ```
5578///
5579/// into:
5580///
5581/// ```
5582/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5583/// : vector<1x4xf32>, tensor<4x4xf32>
5584/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5585/// : vector<1x4xf32>, tensor<4x4xf32>
5586/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5587/// : vector<1x4xf32>, tensor<4x4xf32>
5588/// ```
5589///
5590/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5591/// any other uses.
5592class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5593public:
5594 using Base::Base;
5595 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5596 PatternRewriter &rewriter) const override {
5597 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5598 return failure();
5599 vector::TransferWriteOp writeToModify = writeOp;
5600
5601 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5602 while (defWrite) {
5603 if (checkSameValueWAW(writeOp, defWrite)) {
5604 rewriter.modifyOpInPlace(writeToModify, [&]() {
5605 writeToModify.getBaseMutable().assign(defWrite.getBase());
5606 });
5607 return success();
5608 }
5609 if (!isDisjointTransferSets(
5610 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5611 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5612 break;
5613 // If the previous write op doesn't have any other use we can safely look
5614 // at the previous store to see if it can be removed.
5615 if (!defWrite->hasOneUse())
5616 break;
5617 writeToModify = defWrite;
5618 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5619 }
5620 return failure();
5621 }
5622};
5623
5624/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5625/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5626/// overwritten and inserted into another tensor. After this rewrite, the
5627/// operations bufferize in-place since all of them work on the same slice.
5628///
5629/// For example:
5630/// ```mlir
5631/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5632/// : vector<8x16xf32>, tensor<8x16xf32>
5633/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5634/// : tensor<8x16xf32> to tensor<?x?xf32>
5635/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5636/// : tensor<?x?xf32> into tensor<27x37xf32>
5637/// ```
5638/// folds to
5639/// ```mlir
5640/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5641/// : tensor<27x37xf32> to tensor<?x?xf32>
5642/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5643/// : vector<8x16xf32>, tensor<?x?xf32>
5644/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5645/// : tensor<?x?xf32> into tensor<27x37xf32>
5646/// ```
5647struct SwapExtractSliceOfTransferWrite
5648 : public OpRewritePattern<tensor::InsertSliceOp> {
5649public:
5650 using Base::Base;
5651
5652 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5653 PatternRewriter &rewriter) const override {
5654 if (!insertOp.hasUnitStride())
5655 return failure();
5656 auto extractOp =
5657 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5658 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5659 return failure();
5660 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5661 if (!transferOp || !transferOp->hasOneUse())
5662 return failure();
5663
5664 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5665 // rank-reducing.
5666 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5667 return rewriter.notifyMatchFailure(insertOp,
5668 "use-def chain is rank-reducing");
5669 }
5670
5671 // Fail if tensor::ExtractSliceOp has non-zero offset.
5672 if (!extractOp.hasZeroOffset()) {
5673 return rewriter.notifyMatchFailure(insertOp,
5674 "ExtractSliceOp has non-zero offset");
5675 }
5676
5677 // Fail if vector::TransferWriteOp has non-zero indices.
5678 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5679 return getConstantIntValue(value) == static_cast<int64_t>(0);
5680 })) {
5681 return rewriter.notifyMatchFailure(insertOp,
5682 "TranferWriteOp has non-zero offset");
5683 }
5684
5685 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5686 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5687 return rewriter.notifyMatchFailure(
5688 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5689 }
5690
5691 for (auto [insertSize, extractSize] :
5692 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5693 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5694 return rewriter.notifyMatchFailure(
5695 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5696 }
5697 }
5698
5699 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5700 assert(transferOp.getVectorType().hasStaticShape() &&
5701 "expected vector to have a static shape");
5702 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5703 SmallVector<int64_t> resultShape = applyPermutationMap(
5704 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5705 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5706 return rewriter.notifyMatchFailure(
5707 insertOp, "TransferWriteOp may not write the full tensor.");
5708 }
5709
5710 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5711 // Set all in_bounds to false and let the folder infer them.
5712 SmallVector<bool> newInBounds(vectorShape.size(), false);
5713 auto newExtractOp = tensor::ExtractSliceOp::create(
5714 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5715 insertOp.getDest(), insertOp.getMixedOffsets(),
5716 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5717 auto newTransferWriteOp = TransferWriteOp::create(
5718 rewriter, transferOp.getLoc(), transferOp.getVector(),
5719 newExtractOp.getResult(), transferOp.getIndices(),
5720 transferOp.getPermutationMapAttr(),
5721 rewriter.getBoolArrayAttr(newInBounds));
5722 rewriter.modifyOpInPlace(insertOp, [&]() {
5723 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5724 });
5725 return success();
5726 }
5727};
5728
5729} // namespace
5730
5731void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5732 MLIRContext *context) {
5733 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5734}
5735
5736FailureOr<std::optional<SmallVector<Value>>>
5737TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5738 if (!hasPureBufferSemantics())
5739 return failure();
5741 ValueRange());
5742}
5743
5744//===----------------------------------------------------------------------===//
5745// LoadOp
5746//===----------------------------------------------------------------------===//
5747
5748static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5749 VectorType vecTy,
5750 MemRefType memRefTy) {
5751 // If rank==0 or size==1 it's equivalent to a scalar load/store, so we don't
5752 // need any stride limitations.
5753 if (!vecTy.isScalable() &&
5754 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5755 return success();
5756
5757 if (!memRefTy.isLastDimUnitStride())
5758 return op->emitOpError("most minor memref dim must have unit stride");
5759 return success();
5760}
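// Illustrative examples (not from the upstream source):
//   vector.load %m[%i] : memref<16xf32>, vector<4xf32>        // accepted
//   vector.load %m[%i, %j] : memref<4x8xf32>, vector<8xf32>   // accepted
//   vector.load %m[%i] : memref<16xf32, strided<[2]>>, vector<4xf32>
//     // rejected: most minor memref dim must have unit stride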
5761
5762LogicalResult vector::LoadOp::verify() {
5763 VectorType resVecTy = getVectorType();
5764 MemRefType memRefTy = getMemRefType();
5765
5766 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5767 return failure();
5768
5769 if (memRefTy.getRank() < resVecTy.getRank())
5770 return emitOpError(
5771 "destination memref has lower rank than the result vector");
5772
5773 // Checks for vector memrefs.
5774 Type memElemTy = memRefTy.getElementType();
5775 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5776 if (memVecTy != resVecTy)
5777 return emitOpError("base memref and result vector types should match");
5778 memElemTy = memVecTy.getElementType();
5779 }
5780
5781 if (resVecTy.getElementType() != memElemTy)
5782 return emitOpError("base and result element types should match");
5783 if (llvm::size(getIndices()) != memRefTy.getRank())
5784 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5785 return success();
5786}
5787
5788OpFoldResult LoadOp::fold(FoldAdaptor) {
5789 if (succeeded(memref::foldMemRefCast(*this)))
5790 return getResult();
5791 return OpFoldResult();
5792}
5793
5794std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5795 return llvm::to_vector<4>(getVectorType().getShape());
5796}
5797
5798FailureOr<std::optional<SmallVector<Value>>>
5799LoadOp::bubbleDownCasts(OpBuilder &builder) {
5801 getResult());
5802}
5803
5804//===----------------------------------------------------------------------===//
5805// StoreOp
5806//===----------------------------------------------------------------------===//
5807
5808LogicalResult vector::StoreOp::verify() {
5809 VectorType valueVecTy = getVectorType();
5810 MemRefType memRefTy = getMemRefType();
5811
5812 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5813 return failure();
5814
5815 if (memRefTy.getRank() < valueVecTy.getRank())
5816 return emitOpError("source memref has lower rank than the vector to store");
5817
5818 // Checks for vector memrefs.
5819 Type memElemTy = memRefTy.getElementType();
5820 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5821 if (memVecTy != valueVecTy)
5822 return emitOpError(
5823 "base memref and valueToStore vector types should match");
5824 memElemTy = memVecTy.getElementType();
5825 }
5826
5827 if (valueVecTy.getElementType() != memElemTy)
5828 return emitOpError("base and valueToStore element type should match");
5829 if (llvm::size(getIndices()) != memRefTy.getRank())
5830 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5831 return success();
5832}
5833
5834LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5835 SmallVectorImpl<OpFoldResult> &results) {
5836 return memref::foldMemRefCast(*this);
5837}
5838
5839std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5840 return llvm::to_vector<4>(getVectorType().getShape());
5841}
5842
5843FailureOr<std::optional<SmallVector<Value>>>
5844StoreOp::bubbleDownCasts(OpBuilder &builder) {
5846 ValueRange());
5847}
5848
5849//===----------------------------------------------------------------------===//
5850// MaskedLoadOp
5851//===----------------------------------------------------------------------===//
5852
5853LogicalResult MaskedLoadOp::verify() {
5854 VectorType maskVType = getMaskVectorType();
5855 VectorType passVType = getPassThruVectorType();
5856 VectorType resVType = getVectorType();
5857 MemRefType memType = getMemRefType();
5858
5859 if (resVType.getElementType() != memType.getElementType())
5860 return emitOpError("base and result element type should match");
5861 if (llvm::size(getIndices()) != memType.getRank())
5862 return emitOpError("requires ") << memType.getRank() << " indices";
5863 if (resVType.getShape() != maskVType.getShape())
5864 return emitOpError("expected result shape to match mask shape");
5865 if (resVType != passVType)
5866 return emitOpError("expected pass_thru of same type as result type");
5867 return success();
5868}
5869
5870namespace {
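/// Folds a vector.maskedload whose mask is statically known to be all-true
/// into a regular vector.load, and one whose mask is all-false into its
/// pass-thru value. A sketch of the all-true case (SSA names below are
/// illustrative placeholders, not from the original source):
///
///   %m = vector.constant_mask [8] : vector<8xi1>
///   %0 = vector.maskedload %base[%i], %m, %pass
///       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// becomes:
///   %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>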
5871class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5872public:
5873 using Base::Base;
5874 LogicalResult matchAndRewrite(MaskedLoadOp load,
5875 PatternRewriter &rewriter) const override {
5876 switch (getMaskFormat(load.getMask())) {
5877 case MaskFormat::AllTrue:
5878 rewriter.replaceOpWithNewOp<vector::LoadOp>(
5879 load, load.getType(), load.getBase(), load.getIndices());
5880 return success();
5881 case MaskFormat::AllFalse:
5882 rewriter.replaceOp(load, load.getPassThru());
5883 return success();
5884 case MaskFormat::Unknown:
5885 return failure();
5886 }
5887 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5888 }
5889};
5890} // namespace
5891
5892void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5893 MLIRContext *context) {
5894 results.add<MaskedLoadFolder>(context);
5895}
5896
5897OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5898 if (succeeded(memref::foldMemRefCast(*this)))
5899 return getResult();
5900 return OpFoldResult();
5901}
5902
5903FailureOr<std::optional<SmallVector<Value>>>
5904MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
5906 getResult());
5907}
5908
5909//===----------------------------------------------------------------------===//
5910// MaskedStoreOp
5911//===----------------------------------------------------------------------===//
5912
5913LogicalResult MaskedStoreOp::verify() {
5914 VectorType maskVType = getMaskVectorType();
5915 VectorType valueVType = getVectorType();
5916 MemRefType memType = getMemRefType();
5917
5918 if (valueVType.getElementType() != memType.getElementType())
5919 return emitOpError("base and valueToStore element type should match");
5920 if (llvm::size(getIndices()) != memType.getRank())
5921 return emitOpError("requires ") << memType.getRank() << " indices";
5922 if (valueVType.getShape() != maskVType.getShape())
5923 return emitOpError("expected valueToStore shape to match mask shape");
5924 return success();
5925}
5926
5927namespace {
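/// Folds a vector.maskedstore whose mask is statically all-true into a
/// regular vector.store, and erases it when the mask is all-false. A sketch
/// of the all-true case (SSA names are illustrative placeholders):
///
///   %m = vector.constant_mask [8] : vector<8xi1>
///   vector.maskedstore %base[%i], %m, %val
///       : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// becomes:
///   vector.store %val, %base[%i] : memref<?xf32>, vector<8xf32>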
5928class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5929public:
5930 using Base::Base;
5931 LogicalResult matchAndRewrite(MaskedStoreOp store,
5932 PatternRewriter &rewriter) const override {
5933 switch (getMaskFormat(store.getMask())) {
5934 case MaskFormat::AllTrue:
5935 rewriter.replaceOpWithNewOp<vector::StoreOp>(
5936 store, store.getValueToStore(), store.getBase(), store.getIndices());
5937 return success();
5938 case MaskFormat::AllFalse:
5939 rewriter.eraseOp(store);
5940 return success();
5941 case MaskFormat::Unknown:
5942 return failure();
5943 }
5944 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5945 }
5946};
5947} // namespace
5948
5949void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5950 MLIRContext *context) {
5951 results.add<MaskedStoreFolder>(context);
5952}
5953
5954LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5955 SmallVectorImpl<OpFoldResult> &results) {
5956 return memref::foldMemRefCast(*this);
5957}
5958
5959FailureOr<std::optional<SmallVector<Value>>>
5960MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
5962 ValueRange());
5963}
5964
5965//===----------------------------------------------------------------------===//
5966// GatherOp
5967//===----------------------------------------------------------------------===//
5968
5969LogicalResult GatherOp::verify() {
5970 VectorType indVType = getIndexVectorType();
5971 VectorType maskVType = getMaskVectorType();
5972 VectorType resVType = getVectorType();
5973 ShapedType baseType = getBaseType();
5974
5975 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5976 return emitOpError("requires base to be a memref or ranked tensor type");
5977
5978 if (resVType.getElementType() != baseType.getElementType())
5979 return emitOpError("base and result element type should match");
5980 if (llvm::size(getOffsets()) != baseType.getRank())
5981 return emitOpError("requires ") << baseType.getRank() << " indices";
5982 if (resVType.getShape() != indVType.getShape())
5983 return emitOpError("expected result dim to match indices dim");
5984 if (resVType.getShape() != maskVType.getShape())
5985 return emitOpError("expected result dim to match mask dim");
5986 if (resVType != getPassThruVectorType())
5987 return emitOpError("expected pass_thru of same type as result type");
5988 return success();
5989}
5990
5991// MaskableOpInterface methods.
5992
5993/// Returns the mask type expected by this operation. Mostly used for
5994 /// verification purposes. It requires the operation to be vectorized.
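/// For example, a gather whose index vector type is vector<4x8xi32> is
/// expected to take a mask of type vector<4x8xi1> (same shape and scalable
/// dims, i1 elements).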
5995Type GatherOp::getExpectedMaskType() {
5996 auto vecType = this->getIndexVectorType();
5997 return VectorType::get(vecType.getShape(),
5998 IntegerType::get(vecType.getContext(), /*width=*/1),
5999 vecType.getScalableDims());
6000}
6001
6002std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6003 return llvm::to_vector<4>(getVectorType().getShape());
6004}
6005
6006 /// Check if `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
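/// For example, both of the following qualify (illustrative snippets):
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
///   %idx = vector.step : vector<4xindex>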
6007static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6008 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6009 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6010 return failure();
6011
6012 if (indexVec.getDefiningOp<StepOp>())
6013 return success();
6014
6015 DenseIntElementsAttr elements;
6016 if (!matchPattern(indexVec, m_Constant(&elements)))
6017 return failure();
6018
6019 return success(
6020 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6021}
6022
6023namespace {
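/// Folds a vector.gather whose mask is statically all-false into its
/// pass-thru value; an all-true gather has no unmasked equivalent, so it is
/// left untouched. Sketch of the all-false case (illustrative names):
///
///   %m = vector.constant_mask [0] : vector<8xi1>
///   %0 = vector.gather %base[%c0][%idxs], %m, %pass
///       : memref<?xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
///         into vector<8xf32>
/// becomes simply %pass.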
6024class GatherFolder final : public OpRewritePattern<GatherOp> {
6025public:
6026 using Base::Base;
6027 LogicalResult matchAndRewrite(GatherOp gather,
6028 PatternRewriter &rewriter) const override {
6029 switch (getMaskFormat(gather.getMask())) {
6030 case MaskFormat::AllTrue:
6031 return failure(); // no unmasked equivalent
6032 case MaskFormat::AllFalse:
6033 rewriter.replaceOp(gather, gather.getPassThru());
6034 return success();
6035 case MaskFormat::Unknown:
6036 return failure();
6037 }
6038 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6039 }
6040};
6041
6042/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6043/// maskedload. Only 1D fixed vectors are supported for now.
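/// Sketch of the rewrite (illustrative names; assumes a memref base and a
/// constant [0, 1, 2, 3] index vector, %cst_0123 below):
///
///   %0 = vector.gather %base[%i][%cst_0123], %mask, %pass
///       : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
///         into vector<4xf32>
/// becomes:
///   %0 = vector.maskedload %base[%i], %mask, %pass
///       : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>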
6044class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6045public:
6046 using Base::Base;
6047 LogicalResult matchAndRewrite(GatherOp op,
6048 PatternRewriter &rewriter) const override {
6049 if (!isa<MemRefType>(op.getBase().getType()))
6050 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6051
6052 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6053 return failure();
6054
6055 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6056 op.getOffsets(), op.getMask(),
6057 op.getPassThru());
6058 return success();
6059 }
6060};
6061} // namespace
6062
6063void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6064 MLIRContext *context) {
6065 results.add<GatherFolder, FoldContiguousGather>(context);
6066}
6067
6068FailureOr<std::optional<SmallVector<Value>>>
6069GatherOp::bubbleDownCasts(OpBuilder &builder) {
6071 getResult());
6072}
6073
6074//===----------------------------------------------------------------------===//
6075// ScatterOp
6076//===----------------------------------------------------------------------===//
6077
6078LogicalResult ScatterOp::verify() {
6079 VectorType indVType = getIndexVectorType();
6080 VectorType maskVType = getMaskVectorType();
6081 VectorType valueVType = getVectorType();
6082 ShapedType baseType = getBaseType();
6083
6084 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6085 return emitOpError("requires base to be a memref or ranked tensor type");
6086
6087 if (valueVType.getElementType() != baseType.getElementType())
6088 return emitOpError("base and valueToStore element type should match");
6089 if (llvm::size(getOffsets()) != baseType.getRank())
6090 return emitOpError("requires ") << baseType.getRank() << " indices";
6091 if (valueVType.getShape() != indVType.getShape())
6092 return emitOpError("expected valueToStore dim to match indices dim");
6093 if (valueVType.getShape() != maskVType.getShape())
6094 return emitOpError("expected valueToStore dim to match mask dim");
6095 return success();
6096}
6097namespace {
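/// Folds a vector.scatter whose mask is statically all-false: for a memref
/// base the op is simply erased, for a tensor base it is replaced by the
/// unchanged base tensor. An all-true scatter has no unmasked equivalent.
/// E.g. (illustrative), with %m = vector.constant_mask [0] : vector<8xi1>,
///   vector.scatter %base[%c0][%idxs], %m, %val
///       : memref<?xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32>
/// is erased.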
6098class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6099public:
6100 using Base::Base;
6101 LogicalResult matchAndRewrite(ScatterOp scatter,
6102 PatternRewriter &rewriter) const override {
6103 ShapedType baseType = scatter.getBaseType();
6104 bool isMemRef = isa<MemRefType>(baseType);
6105 if (!isMemRef && !isa<RankedTensorType>(baseType))
6106 return failure();
6107
6108 // Memrefs have no result, so an all-false mask can simply erase the op.
6109 // Tensors carry the updated value, so we must replace uses with the
6110 // original base tensor instead of erasing.
6111 switch (getMaskFormat(scatter.getMask())) {
6112 case MaskFormat::AllTrue:
6113 return failure(); // no unmasked equivalent
6114 case MaskFormat::AllFalse:
6115 if (isMemRef)
6116 rewriter.eraseOp(scatter);
6117 else
6118 rewriter.replaceOp(scatter, scatter.getBase());
6119 return success();
6120 case MaskFormat::Unknown:
6121 return failure();
6122 }
6123 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6124 }
6125};
6126
6127/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6128/// maskedstore. Only 1D fixed vectors are supported for now.
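/// Sketch of the rewrite (illustrative names; the index vector must be the
/// constant sequence [0, 1, 2, ...], %cst_0123 below):
///
///   vector.scatter %base[%i][%cst_0123], %mask, %val
///       : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
/// becomes:
///   vector.maskedstore %base[%i], %mask, %val
///       : memref<?xf32>, vector<4xi1>, vector<4xf32>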
6129class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6130public:
6131 using Base::Base;
6132 LogicalResult matchAndRewrite(ScatterOp op,
6133 PatternRewriter &rewriter) const override {
6134 // Fold only for memrefs: the replacement uses maskedstore, which does not
6135 // support tensor bases. Tensor cases intentionally bail out.
6136 if (!isa<MemRefType>(op.getBase().getType()))
6137 return failure();
6138
6139 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6140 return failure();
6141
6142 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6143 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6144 return success();
6145 }
6146};
6147} // namespace
6148
6149void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6150 MLIRContext *context) {
6151 results.add<ScatterFolder, FoldContiguousScatter>(context);
6152}
6153
6154FailureOr<std::optional<SmallVector<Value>>>
6155ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6157 ValueRange());
6158}
6159
6160//===----------------------------------------------------------------------===//
6161// ExpandLoadOp
6162//===----------------------------------------------------------------------===//
6163
6164LogicalResult ExpandLoadOp::verify() {
6165 VectorType maskVType = getMaskVectorType();
6166 VectorType passVType = getPassThruVectorType();
6167 VectorType resVType = getVectorType();
6168 MemRefType memType = getMemRefType();
6169
6170 if (resVType.getElementType() != memType.getElementType())
6171 return emitOpError("base and result element type should match");
6172 if (llvm::size(getIndices()) != memType.getRank())
6173 return emitOpError("requires ") << memType.getRank() << " indices";
6174 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6175 return emitOpError("expected result dim to match mask dim");
6176 if (resVType != passVType)
6177 return emitOpError("expected pass_thru of same type as result type");
6178 return success();
6179}
6180
6181namespace {
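/// Folds a vector.expandload with a statically all-true mask into a regular
/// vector.load (all lanes are read contiguously), and one with an all-false
/// mask into its pass-thru value. Sketch of the all-true case (illustrative):
///
///   %0 = vector.expandload %base[%i], %all_true, %pass
///       : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
/// becomes:
///   %0 = vector.load %base[%i] : memref<?xf32>, vector<8xf32>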
6182class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6183public:
6184 using Base::Base;
6185 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6186 PatternRewriter &rewriter) const override {
6187 switch (getMaskFormat(expand.getMask())) {
6188 case MaskFormat::AllTrue:
6189 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6190 expand, expand.getType(), expand.getBase(), expand.getIndices());
6191 return success();
6192 case MaskFormat::AllFalse:
6193 rewriter.replaceOp(expand, expand.getPassThru());
6194 return success();
6195 case MaskFormat::Unknown:
6196 return failure();
6197 }
6198 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6199 }
6200};
6201} // namespace
6202
6203void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6204 MLIRContext *context) {
6205 results.add<ExpandLoadFolder>(context);
6206}
6207
6208FailureOr<std::optional<SmallVector<Value>>>
6209ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6211 getResult());
6212}
6213
6214//===----------------------------------------------------------------------===//
6215// CompressStoreOp
6216//===----------------------------------------------------------------------===//
6217
6218LogicalResult CompressStoreOp::verify() {
6219 VectorType maskVType = getMaskVectorType();
6220 VectorType valueVType = getVectorType();
6221 MemRefType memType = getMemRefType();
6222
6223 if (valueVType.getElementType() != memType.getElementType())
6224 return emitOpError("base and valueToStore element type should match");
6225 if (llvm::size(getIndices()) != memType.getRank())
6226 return emitOpError("requires ") << memType.getRank() << " indices";
6227 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6228 return emitOpError("expected valueToStore dim to match mask dim");
6229 return success();
6230}
6231
6232namespace {
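/// Folds a vector.compressstore with a statically all-true mask into a
/// regular vector.store (all lanes are written contiguously), and erases it
/// when the mask is all-false. Sketch of the all-true case (illustrative):
///
///   vector.compressstore %base[%i], %all_true, %val
///       : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// becomes:
///   vector.store %val, %base[%i] : memref<?xf32>, vector<8xf32>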
6233class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6234public:
6235 using Base::Base;
6236 LogicalResult matchAndRewrite(CompressStoreOp compress,
6237 PatternRewriter &rewriter) const override {
6238 switch (getMaskFormat(compress.getMask())) {
6239 case MaskFormat::AllTrue:
6240 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6241 compress, compress.getValueToStore(), compress.getBase(),
6242 compress.getIndices());
6243 return success();
6244 case MaskFormat::AllFalse:
6245 rewriter.eraseOp(compress);
6246 return success();
6247 case MaskFormat::Unknown:
6248 return failure();
6249 }
6250 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6251 }
6252};
6253} // namespace
6254
6255void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6256 MLIRContext *context) {
6257 results.add<CompressStoreFolder>(context);
6258}
6259
6260FailureOr<std::optional<SmallVector<Value>>>
6261CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6263 ValueRange());
6264}
6265
6266//===----------------------------------------------------------------------===//
6267// ShapeCastOp
6268//===----------------------------------------------------------------------===//
6269
6270void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6271 SetIntRangeFn setResultRanges) {
6272 setResultRanges(getResult(), argRanges.front());
6273}
6274
6275std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6276 return llvm::to_vector<4>(getResultVectorType().getShape());
6277}
6278
6279LogicalResult ShapeCastOp::verify() {
6280
6281 VectorType sourceType = getSourceVectorType();
6282 VectorType resultType = getResultVectorType();
6283
6284 // Check that element type is preserved
6285 if (sourceType.getElementType() != resultType.getElementType())
6286 return emitOpError("has different source and result element types");
6287
6288 // Check that number of elements is preserved
6289 int64_t sourceNElms = sourceType.getNumElements();
6290 int64_t resultNElms = resultType.getNumElements();
6291 if (sourceNElms != resultNElms) {
6292 return emitOpError() << "has different number of elements at source ("
6293 << sourceNElms << ") and result (" << resultNElms
6294 << ")";
6295 }
6296
6297 // Check that (non-)scalability is preserved
6298 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6299 int64_t resultNScalableDims = resultType.getNumScalableDims();
6300 if (sourceNScalableDims != resultNScalableDims)
6301 return emitOpError() << "has different number of scalable dims at source ("
6302 << sourceNScalableDims << ") and result ("
6303 << resultNScalableDims << ")";
6304
6305 return success();
6306}
6307
6308/// Return true if `transpose` does not permute a pair of non-unit dims.
6309/// By `order preserving` we mean that the flattened versions of the input and
6310/// output vectors are (numerically) identical. In other words `transpose` is
6311/// effectively a shape cast.
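/// For example, on vector<4x1x8xf32> the permutation [0, 2, 1] is order
/// preserving (it only moves a unit dim), whereas [1, 0, 2] on
/// vector<4x2x8xf32> is not.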
6312static bool isOrderPreserving(TransposeOp transpose) {
6313 ArrayRef<int64_t> permutation = transpose.getPermutation();
6314 VectorType sourceType = transpose.getSourceVectorType();
6315 ArrayRef<int64_t> inShape = sourceType.getShape();
6316 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6317 auto isNonScalableUnitDim = [&](int64_t dim) {
6318 return inShape[dim] == 1 && !inDimIsScalable[dim];
6319 };
6320 int64_t current = 0;
6321 for (auto p : permutation) {
6322 if (!isNonScalableUnitDim(p)) {
6323 if (p < current) {
6324 return false;
6325 }
6326 current = p;
6327 }
6328 }
6329 return true;
6330}
6331
6332OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6333
6334 VectorType resultType = getType();
6335
6336 // No-op shape cast.
6337 if (getSource().getType() == resultType)
6338 return getSource();
6339
6340 // shape_cast(shape_cast(x)) -> shape_cast(x)
6341 if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6342 setOperand(precedingShapeCast.getSource());
6343 return getResult();
6344 }
6345
6346 // shape_cast(transpose(x)) -> shape_cast(x)
6347 if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6348 if (isOrderPreserving(transpose)) {
6349 setOperand(transpose.getVector());
6350 return getResult();
6351 }
6352 return {};
6353 }
6354
6355 // Y = shape_cast(broadcast(X))
6356 // -> X, if X and Y have same type
6357 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6358 if (bcastOp.getSourceType() == resultType)
6359 return bcastOp.getSource();
6360 }
6361
6362 // shape_cast(constant) -> constant
6363 if (auto denseAttr =
6364 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6365 return denseAttr.reshape(getType());
6366
6367 // shape_cast(poison) -> poison
6368 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
6369 return ub::PoisonAttr::get(getContext());
6370
6371 return {};
6372}
6373
6374namespace {
6375
6376/// Helper function that computes a new vector type based on the input vector
6377/// type by removing the trailing one dims:
6378///
6379/// vector<4x1x1xi1> --> vector<4x1xi1>
6380///
6381static VectorType trimTrailingOneDims(VectorType oldType) {
6382 ArrayRef<int64_t> oldShape = oldType.getShape();
6383 ArrayRef<int64_t> newShape = oldShape;
6384
6385 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6386 ArrayRef<bool> newScalableDims = oldScalableDims;
6387
6388 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6389 newShape = newShape.drop_back(1);
6390 newScalableDims = newScalableDims.drop_back(1);
6391 }
6392
6393 // Make sure we have at least 1 dimension.
6394 // TODO: Add support for 0-D vectors.
6395 if (newShape.empty()) {
6396 newShape = oldShape.take_back();
6397 newScalableDims = oldScalableDims.take_back();
6398 }
6399
6400 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6401}
6402
6403/// Folds qualifying shape_cast(create_mask) into a new create_mask
6404///
6405/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6406/// dimension. If the input vector comes from `vector.create_mask` for which
6407/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6408/// to fold shape_cast into create_mask.
6409///
6410/// BEFORE:
6411/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6412/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6413/// AFTER:
6414/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6415class ShapeCastCreateMaskFolderTrailingOneDim final
6416 : public OpRewritePattern<ShapeCastOp> {
6417public:
6418 using Base::Base;
6419
6420 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6421 PatternRewriter &rewriter) const override {
6422 Value shapeOpSrc = shapeOp->getOperand(0);
6423 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6424 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6425 if (!createMaskOp && !constantMaskOp)
6426 return failure();
6427
6428 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6429 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6430
6431 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6432 if (newVecType != shapeOpResTy)
6433 return failure();
6434
6435 auto numDimsToDrop =
6436 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6437
6438 // No unit dims to drop
6439 if (!numDimsToDrop)
6440 return failure();
6441
6442 if (createMaskOp) {
6443 auto maskOperands = createMaskOp.getOperands();
6444 auto numMaskOperands = maskOperands.size();
6445
6446 // Check every mask dim size to see whether it can be dropped
6447 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6448 --i) {
6449 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6450 if (!constant || (constant.value() != 1))
6451 return failure();
6452 }
6453 SmallVector<Value> newMaskOperands =
6454 maskOperands.drop_back(numDimsToDrop);
6455
6456 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6457 newMaskOperands);
6458 return success();
6459 }
6460
6461 if (constantMaskOp) {
6462 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6463 auto numMaskOperands = maskDimSizes.size();
6464
6465 // Check every mask dim size to see whether it can be dropped
6466 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6467 --i) {
6468 if (maskDimSizes[i] != 1)
6469 return failure();
6470 }
6471
6472 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6473 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6474 newMaskOperands);
6475 return success();
6476 }
6477
6478 return failure();
6479 }
6480};
6481
6482/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
6483/// i) Y = ShapeCast(X), or
6484/// ii) Y = Broadcast(X)
6485/// If both (i) and (ii) are possible, (i) is chosen.
6486class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6487public:
6488 using Base::Base;
6489
6490 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6491 PatternRewriter &rewriter) const override {
6492 auto broadcastOp =
6493 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6494 if (!broadcastOp)
6495 return failure();
6496
6497 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6498 bool srcIsScalar = !srcVectorType;
6499
6500 // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6501 // Example:
6502 // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6503 // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6504 // to
6505 // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6506 if (srcVectorType) {
6507 if (srcVectorType.getNumElements() ==
6508 shapeCastOp.getResultVectorType().getNumElements()) {
6509 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6510 shapeCastOp, shapeCastOp.getResultVectorType(),
6511 broadcastOp.getSource());
6512 return success();
6513 }
6514 }
6515
6516 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6517 // Example
6518 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6519 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6520 // to
6521 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6522 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6523 if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6524 BroadcastableToResult::Success) {
6525 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6526 shapeCastOp, dstVectorType, broadcastOp.getSource());
6527 return success();
6528 }
6529 return failure();
6530 }
6531};
6532
6533} // namespace
6534
6535void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6536 MLIRContext *context) {
6537 results
6538 .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6539 context);
6540}
6541
6542//===----------------------------------------------------------------------===//
6543// VectorBitCastOp
6544//===----------------------------------------------------------------------===//
6545
6546LogicalResult BitCastOp::verify() {
6547 auto sourceVectorType = getSourceVectorType();
6548 auto resultVectorType = getResultVectorType();
6549
6550 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6551 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6552 return emitOpError("dimension size mismatch at: ") << i;
6553 }
6554
6555 DataLayout dataLayout = DataLayout::closest(*this);
6556 auto sourceElementBits =
6557 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6558 auto resultElementBits =
6559 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6560
6561 if (sourceVectorType.getRank() == 0) {
6562 if (sourceElementBits != resultElementBits)
6563 return emitOpError("source/result bitwidth of the 0-D vector element "
6564 "types must be equal");
6565 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6566 resultElementBits * resultVectorType.getShape().back()) {
6567 return emitOpError(
6568 "source/result bitwidth of the minor 1-D vectors must be equal");
6569 }
6570
6571 return success();
6572}
6573
6574OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6575 // Nop cast.
6576 if (getSource().getType() == getResult().getType())
6577 return getSource();
6578
6579 // Canceling bitcasts.
6580 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6581 if (getResult().getType() == otherOp.getSource().getType())
6582 return otherOp.getSource();
6583
6584 setOperand(otherOp.getSource());
6585 return getResult();
6586 }
6587
6588 Attribute sourceConstant = adaptor.getSource();
6589 if (!sourceConstant)
6590 return {};
6591
6592 Type srcElemType = getSourceVectorType().getElementType();
6593 Type dstElemType = getResultVectorType().getElementType();
6594
6595 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6596 if (floatPack.isSplat()) {
6597 auto splat = floatPack.getSplatValue<FloatAttr>();
6598
6599 // Casting fp16 into fp32.
6600 if (srcElemType.isF16() && dstElemType.isF32()) {
6601 uint32_t bits = static_cast<uint32_t>(
6602 splat.getValue().bitcastToAPInt().getZExtValue());
6603 // Duplicate the 16-bit pattern.
6604 bits = (bits << 16) | (bits & 0xffff);
6605 APInt intBits(32, bits);
6606 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6607 return DenseElementsAttr::get(getResultVectorType(), floatBits);
6608 }
6609 }
6610 }
6611
6612 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6613 if (intPack.isSplat()) {
6614 auto splat = intPack.getSplatValue<IntegerAttr>();
6615
6616 if (llvm::isa<IntegerType>(dstElemType)) {
6617 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
6618 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
6619
6620 // Casting to a larger integer bit width.
6621 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6622 APInt intBits = splat.getValue().zext(dstBitWidth);
6623
6624 // Duplicate the lower width element.
6625 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6626 intBits = (intBits << srcBitWidth) | intBits;
6627 return DenseElementsAttr::get(getResultVectorType(), intBits);
6628 }
6629 }
6630 }
6631 }
6632
6633 return {};
6634}
6635
6636//===----------------------------------------------------------------------===//
6637// TypeCastOp
6638//===----------------------------------------------------------------------===//
6639
6640static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6641 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6642 SmallVector<int64_t, 8> res(memRefType.getShape());
6643 if (vectorType)
6644 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6645 return res;
6646}
6647
6648/// Build the canonical memRefType with a single vector.
6649/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6650void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6651 Value source) {
6652 result.addOperands(source);
6653 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6654 VectorType vectorType =
6655 VectorType::get(extractShape(memRefType),
6656 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
6657 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6658 memRefType.getMemorySpace()));
6659}
6660
6661LogicalResult TypeCastOp::verify() {
6662 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6663 if (!canonicalType.getLayout().isIdentity())
6664 return emitOpError("expects operand to be a memref with identity layout");
6665 if (!getResultMemRefType().getLayout().isIdentity())
6666 return emitOpError("expects result to be a memref with identity layout");
6667 if (getResultMemRefType().getMemorySpace() !=
6668 getMemRefType().getMemorySpace())
6669 return emitOpError("expects result in same memory space");
6670
6671 auto sourceType = getMemRefType();
6672 auto resultType = getResultMemRefType();
6673 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
6674 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
6675 return emitOpError(
6676 "expects result and operand with same underlying scalar type: ")
6677 << resultType;
6678 if (extractShape(sourceType) != extractShape(resultType))
6679 return emitOpError(
6680 "expects concatenated result and operand shapes to be equal: ")
6681 << resultType;
6682 return success();
6683}
6684
6685//===----------------------------------------------------------------------===//
6686// TransposeOp
6687//===----------------------------------------------------------------------===//
6688
6689void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6690 Value vector, ArrayRef<int64_t> permutation) {
6691 VectorType vt = llvm::cast<VectorType>(vector.getType());
6692 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6693 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
6694 for (unsigned i = 0; i < permutation.size(); ++i) {
6695 transposedShape[i] = vt.getShape()[permutation[i]];
6696 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6697 }
6698
6699 result.addOperands(vector);
6700 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6701 transposedScalableDims));
6702 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6703 builder.getDenseI64ArrayAttr(permutation));
6704}
6705
6706OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6707 // Eliminate splat constant transpose ops.
6708 if (auto splat =
6709 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6710 return splat.reshape(getResultVectorType());
6711
6712 // Eliminate poison transpose ops.
6713 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6714 return ub::PoisonAttr::get(getContext());
6715
6716 // Eliminate identity transposes, and more generally any transpose that
6717 // preserves the shape without permuting elements.
6718 //
6719 // Examples of what to fold:
6720 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6721 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6722 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6723 //
6724 // Example of what NOT to fold:
6725 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6726 //
6727 if (getSourceVectorType() == getResultVectorType() &&
6728 isOrderPreserving(*this))
6729 return getVector();
6730
6731 return {};
6732}
6733
6734LogicalResult vector::TransposeOp::verify() {
6735 VectorType vectorType = getSourceVectorType();
6736 VectorType resultType = getResultVectorType();
6737 int64_t rank = resultType.getRank();
6738 if (vectorType.getRank() != rank)
6739 return emitOpError("vector result rank mismatch: ") << rank;
6740 // Verify transposition array.
6741 ArrayRef<int64_t> perm = getPermutation();
6742 int64_t size = perm.size();
6743 if (rank != size)
6744 return emitOpError("transposition length mismatch: ") << size;
6745 SmallVector<bool, 8> seen(rank, false);
6746 for (const auto &ta : llvm::enumerate(perm)) {
6747 if (ta.value() < 0 || ta.value() >= rank)
6748 return emitOpError("transposition index out of range: ") << ta.value();
6749 if (seen[ta.value()])
6750 return emitOpError("duplicate position index: ") << ta.value();
6751 seen[ta.value()] = true;
6752 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6753 return emitOpError("dimension size mismatch at: ") << ta.value();
6754 }
6755 return success();
6756}
6757
6758std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6759 return llvm::to_vector<4>(getResultVectorType().getShape());
6760}
6761
6762void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6763 SetIntRangeFn setResultRanges) {
6764 setResultRanges(getResult(), argRanges.front());
6765}
6766
6767namespace {
6768
6769// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
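// For example (illustrative):
//   %0 = vector.transpose %v, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
//   %1 = vector.transpose %0, [1, 0] : vector<3x2xf32> to vector<2x3xf32>
// folds to a single transpose with the composed permutation [0, 1].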
6770class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6771public:
6772 using Base::Base;
6773
6774 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6775 PatternRewriter &rewriter) const override {
6776 // Composes two permutations: result[i] = permutation1[permutation2[i]].
6777 auto composePermutations = [](ArrayRef<int64_t> permutation1,
6778 ArrayRef<int64_t> permutation2) {
6779 SmallVector<int64_t, 4> result;
6780 for (auto index : permutation2)
6781 result.push_back(permutation1[index]);
6782 return result;
6783 };
6784
6785 // Return if the input of 'transposeOp' is not defined by another transpose.
6786 vector::TransposeOp parentTransposeOp =
6787 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6788 if (!parentTransposeOp)
6789 return failure();
6790
6791 SmallVector<int64_t, 4> permutation = composePermutations(
6792 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6793 // Replace 'transposeOp' with a new transpose operation.
6794 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6795 transposeOp, transposeOp.getResult().getType(),
6796 parentTransposeOp.getVector(), permutation);
6797 return success();
6798 }
6799};
6800
6801/// Replace transpose(splat-like(v)) with broadcast(v)
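/// E.g. (illustrative), if %v is a splat of scalar %s:
///   %t = vector.transpose %v, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
/// becomes:
///   %t = vector.broadcast %s : f32 to vector<8x4xf32>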
6802class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6803public:
6804 using Base::Base;
6805
6806 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6807 PatternRewriter &rewriter) const override {
6808 Value splat = getScalarSplatSource(transposeOp.getVector());
6809 if (!splat)
6810 return failure();
6811
6812 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6813 transposeOp, transposeOp.getResultVectorType(), splat);
6814 return success();
6815 }
6816};
6817
6818/// Folds transpose(create_mask) into a new transposed create_mask.
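/// E.g. (illustrative):
///   %m = vector.create_mask %a, %b : vector<4x8xi1>
///   %t = vector.transpose %m, [1, 0] : vector<4x8xi1> to vector<8x4xi1>
/// becomes:
///   %t = vector.create_mask %b, %a : vector<8x4xi1>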
6819class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6820public:
6821 using Base::Base;
6822
6823 LogicalResult matchAndRewrite(TransposeOp transpOp,
6824 PatternRewriter &rewriter) const override {
6825 Value transposeSrc = transpOp.getVector();
6826 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6827 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6828 if (!createMaskOp && !constantMaskOp)
6829 return failure();
6830
6831 // Get the transpose permutation and apply it to the vector.create_mask or
6832 // vector.constant_mask operands.
6833 ArrayRef<int64_t> permutation = transpOp.getPermutation();
6834
6835 if (createMaskOp) {
6836 auto maskOperands = createMaskOp.getOperands();
6837 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6838 applyPermutationToVector(newOperands, permutation);
6839
6840 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6841 transpOp, transpOp.getResultVectorType(), newOperands);
6842 return success();
6843 }
6844
6845 // ConstantMaskOp case.
6846 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6847 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6848
6849 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6850 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6851 return success();
6852 }
6853};
6854
6855/// Folds transpose(shape_cast) into a new shape_cast.
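/// E.g. (illustrative), when the transpose only moves unit dims:
///   %0 = vector.shape_cast %v : vector<4x8xf32> to vector<1x4x8xf32>
///   %1 = vector.transpose %0, [1, 0, 2]
///       : vector<1x4x8xf32> to vector<4x1x8xf32>
/// becomes:
///   %1 = vector.shape_cast %v : vector<4x8xf32> to vector<4x1x8xf32>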
6856class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6857public:
6858 using Base::Base;
6859
6860 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6861 PatternRewriter &rewriter) const override {
6862 auto shapeCastOp =
6863 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6864 if (!shapeCastOp)
6865 return failure();
6866 if (!isOrderPreserving(transposeOp))
6867 return failure();
6868
6869 VectorType resultType = transposeOp.getType();
6870
6871 // We don't need to check isValidShapeCast at this point: merging the
6872 // transpose into the shape_cast is guaranteed to produce a valid
6873 // shape_cast, because the transpose only inserts/removes unit dims.
6874
6875 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6876 shapeCastOp.getSource());
6877 return success();
6878 }
6879};
6880
6881/// Folds transpose(from_elements(...)) into a new from_elements with permuted
6882/// operands matching the transposed shape.
6883///
6884/// Example:
6885///
6886 /// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12
6887 ///     : vector<2x3xi32>
6888 /// %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
6889///
6890/// becomes ->
6891///
6892/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
6893/// vector<3x2xi32>
6894///
6895class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
6896public:
6897 using Base::Base;
6898 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6899 PatternRewriter &rewriter) const override {
6900 auto fromElementsOp =
6901 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
6902 if (!fromElementsOp)
6903 return failure();
6904
6905 VectorType srcTy = fromElementsOp.getDest().getType();
6906 VectorType dstTy = transposeOp.getType();
6907
6908 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
6909 int64_t rank = srcTy.getRank();
6910
6911 // Build inverse permutation to map destination indices back to source.
6912 SmallVector<int64_t> inversePerm(rank, 0);
6913 for (int64_t i = 0; i < rank; ++i)
6914 inversePerm[permutation[i]] = i;
6915
6916 ArrayRef<int64_t> srcShape = srcTy.getShape();
6917 ArrayRef<int64_t> dstShape = dstTy.getShape();
6918 SmallVector<int64_t> srcIdx(rank, 0);
6919 SmallVector<int64_t> dstIdx(rank, 0);
6920 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
6921 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
6922
6923 auto elementsOld = fromElementsOp.getElements();
6924 SmallVector<Value> elementsNew;
6925 int64_t dstNumElements = dstTy.getNumElements();
6926 elementsNew.reserve(dstNumElements);
6927
6928 // For each element in destination row-major order, pick the corresponding
6929 // source element.
6930 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
6931 // Pick the destination element index.
6932 dstIdx = delinearize(linearIdx, dstStrides);
6933 // Map the destination element index to the source element index.
6934 for (int64_t j = 0; j < rank; ++j)
6935 srcIdx[j] = dstIdx[inversePerm[j]];
6936 // Linearize the source element index.
6937 int64_t srcLin = linearize(srcIdx, srcStrides);
6938 // Add the source element to the new elements.
6939 elementsNew.push_back(elementsOld[srcLin]);
6940 }
6941
6942 rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
6943 elementsNew);
6944 return success();
6945 }
6946};
6947
6948/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6949/// 'order preserving', where 'order preserving' means the flattened
6950/// inputs and outputs of the transpose have identical (numerical) values.
6951///
6952/// Example:
6953/// ```
6954/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6955/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6956/// to vector<8x1xi32>
6957/// ```
6958/// can be rewritten as the equivalent
6959/// ```
6960/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6961/// ```
6962/// The algorithm works by partitioning dimensions into groups that can be
6963/// locally permuted while preserving order, and checks that the transpose
6964/// only permutes within these groups.
6965///
6966/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6967/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6968/// broadcasting from 1x1x4x1x1x7.
6969/// ^^^ ^ ^^^ ^
6970/// groups: 0 1 2 3
6971/// Order preserving permutations for this example are ones that only permute
6972/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6973class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6974public:
6975 using Base::Base;
6976 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6977 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6978
6979 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6980 PatternRewriter &rewriter) const override {
6981
6982 vector::BroadcastOp broadcast =
6983 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6984 if (!broadcast) {
6985 return rewriter.notifyMatchFailure(transpose,
6986 "not preceded by a broadcast");
6987 }
6988
6989 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6990 VectorType outputType = transpose.getResultVectorType();
6991
6992 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6993 bool inputIsScalar = !inputType;
6994 if (inputIsScalar) {
6995 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6996 broadcast.getSource());
6997 return success();
6998 }
6999
7000 ArrayRef<int64_t> permutation = transpose.getPermutation();
7001 ArrayRef<int64_t> inputShape = inputType.getShape();
7002 int64_t inputRank = inputType.getRank();
7003 int64_t outputRank = transpose.getType().getRank();
7004 int64_t deltaRank = outputRank - inputRank;
7005
7006 int low = 0;
7007 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7008 bool notOne = inputShape[inputIndex] != 1;
7009 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7010 bool groupEndFound = notOne || prevNotOne;
7011 if (groupEndFound) {
7012 int high = inputIndex + deltaRank;
7013 // Return failure if not all permutation destinations for indices in
7014 // [low, high) are in [low, high), i.e. the permutation is not local to
7015 // the group.
7016 for (int i = low; i < high; ++i) {
7017 if (permutation[i] < low || permutation[i] >= high) {
7018 return rewriter.notifyMatchFailure(
7019 transpose, "permutation not local to group");
7020 }
7021 }
7022 low = high;
7023 }
7024 }
7025
7026 // We don't need to check the final group [low, outputRank) because if it is
7027 // not locally bound, there must be a preceding group that already failed
7028 // the check (impossible to have just 1 non-locally bound group).
7029
7030 // The preceding logic also ensures that at this point, the output of the
7031 // transpose is definitely broadcastable from the input shape, assert so:
7032 assert(vector::isBroadcastableTo(inputType, outputType) ==
7033 vector::BroadcastableToResult::Success &&
7034 "not broadcastable directly to transpose output");
7035
7036 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7037 broadcast.getSource());
7038
7039 return success();
7040 }
7041};
7042
7043} // namespace
7044
7045void vector::TransposeOp::getCanonicalizationPatterns(
7046 RewritePatternSet &results, MLIRContext *context) {
7047 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7048 FoldTransposeSplat, FoldTransposeFromElements,
7049 FoldTransposeBroadcast>(context);
7050}
7051
7052//===----------------------------------------------------------------------===//
7053// ConstantMaskOp
7054//===----------------------------------------------------------------------===//
7055
7056void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7057 VectorType type, ConstantMaskKind kind) {
7058 assert(kind == ConstantMaskKind::AllTrue ||
7059 kind == ConstantMaskKind::AllFalse);
7060 build(builder, result, type,
7061 kind == ConstantMaskKind::AllTrue
7062 ? type.getShape()
7063 : SmallVector<int64_t>(type.getRank(), 0));
7064}
7065
7066LogicalResult ConstantMaskOp::verify() {
7067 auto resultType = llvm::cast<VectorType>(getResult().getType());
7068 // Check the corner case of 0-D vectors first.
7069 if (resultType.getRank() == 0) {
7070 if (getMaskDimSizes().size() != 1)
7071 return emitError("array attr must have length 1 for 0-D vectors");
7072 auto dim = getMaskDimSizes()[0];
7073 if (dim != 0 && dim != 1)
7074 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7075 return success();
7076 }
7077
7078 // Verify that array attr size matches the rank of the vector result.
7079 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7080 return emitOpError(
7081 "must specify array attr of size equal vector result rank");
7082 // Verify that each array attr element is in bounds of corresponding vector
7083 // result dimension size.
7084 auto resultShape = resultType.getShape();
7085 auto resultScalableDims = resultType.getScalableDims();
7086 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7087 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7088 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7089 return emitOpError(
7090 "array attr of size out of bounds of vector result dimension size");
7091 if (resultScalableDims[index] && maskDimSize != 0 &&
7092 maskDimSize != resultShape[index])
7093 return emitOpError(
7094 "only supports 'none set' or 'all set' scalable dimensions");
7095 }
7096 // Verify that if one mask dim size is zero, they all should be zero (because
7097 // the mask region is a conjunction of each mask dimension interval).
7098 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7099 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7100 if (anyZeros && !allZeros)
7101 return emitOpError("expected all mask dim sizes to be zeros, "
7102 "as a result of conjunction with zero mask dim");
7103 return success();
7104}
7105
7106bool ConstantMaskOp::isAllOnesMask() {
7107 auto resultType = getVectorType();
7108 // Check the corner case of 0-D vectors first.
7109 if (resultType.getRank() == 0) {
7110 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7111 return getMaskDimSizes()[0] == 1;
7112 }
7113 for (const auto [resultSize, maskDimSize] :
7114 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7115 if (maskDimSize < resultSize)
7116 return false;
7117 }
7118 return true;
7119}
7120
7121OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7122 ArrayRef<int64_t> bounds = getMaskDimSizes();
7123 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7124
7125 auto createBoolSplat = [&](bool x) {
7126 return SplatElementsAttr::get(getVectorType(),
7128 };
7129
7130 // Check the corner case of 0-D vectors first.
7131 if (vectorSizes.empty()) {
7132 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7133 return createBoolSplat(bounds[0] == 1);
7134 }
7135 // Fold vector.constant_mask to splat if possible.
7136 if (bounds == vectorSizes)
7137 return createBoolSplat(true);
7138 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7139 return createBoolSplat(false);
7140 return OpFoldResult();
7141}
7142
7143//===----------------------------------------------------------------------===//
7144// CreateMaskOp
7145//===----------------------------------------------------------------------===//
7146
7147void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7148 VectorType type,
7149 ArrayRef<OpFoldResult> mixedOperands) {
7150 SmallVector<Value> operands =
7151 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7152 build(builder, result, type, operands);
7153}
7154
7155LogicalResult CreateMaskOp::verify() {
7156 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7157 // Verify that an operand was specified for each result vector dimension.
7158 if (vectorType.getRank() == 0) {
7159 if (getNumOperands() != 1)
7160 return emitOpError(
7161 "must specify exactly one operand for 0-D create_mask");
7162 } else if (getNumOperands() !=
7163 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7164 return emitOpError(
7165 "must specify an operand for each result vector dimension");
7166 }
7167 return success();
7168}
7169
7170namespace {
7171
7172/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7173///
7174/// Ex 1:
7175/// %c2 = arith.constant 2 : index
7176/// %c3 = arith.constant 3 : index
7177/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7178/// Becomes:
7179/// vector.constant_mask [3, 2] : vector<4x3xi1>
7180///
7181/// Ex 2:
7182/// %c_neg_1 = arith.constant -1 : index
7183/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7184/// becomes:
7185/// vector.constant_mask [0] : vector<[8]xi1>
7186///
7187/// Ex 3:
7188/// %c8 = arith.constant 8 : index
7189/// %c16 = arith.constant 16 : index
7190/// %0 = vector.vscale
7191/// %1 = arith.muli %0, %c16 : index
7192/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7193/// becomes:
7194/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7195class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7196public:
7197 using Base::Base;
7198
7199 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7200 PatternRewriter &rewriter) const override {
7201 VectorType maskType = createMaskOp.getVectorType();
7202 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7203 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7204
7205 // Special case: Rank zero shape.
7206 constexpr std::array<int64_t, 1> rankZeroShape{1};
7207 constexpr std::array<bool, 1> rankZeroScalableDims{false};
7208 if (maskType.getRank() == 0) {
7209 maskTypeDimSizes = rankZeroShape;
7210 maskTypeDimScalableFlags = rankZeroScalableDims;
7211 }
7212
7213 // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7214 // collect the `constantDims` (for the ConstantMaskOp).
7215 SmallVector<int64_t, 4> constantDims;
7216 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7217 if (auto intSize = getConstantIntValue(dimSize)) {
7218 // Constant value.
7219 // If the mask dim is non-scalable this can be any value.
7220 // If the mask dim is scalable only zero (all-false) is supported.
7221 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7222 return failure();
7223 constantDims.push_back(*intSize);
7224 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7225 // Constant vscale multiple (e.g. 4 x vscale).
7226 // Must be all-true to fold to a ConstantMask.
7227 if (vscaleMultiplier < maskTypeDimSizes[i])
7228 return failure();
7229 constantDims.push_back(*vscaleMultiplier);
7230 } else {
7231 return failure();
7232 }
7233 }
7234
7235 // Clamp values to constant_mask bounds.
7236 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7237 value = std::clamp<int64_t>(value, 0, maskDimSize);
7238
7239 // If one of dim sizes is zero, set all dims to zero.
7240 if (llvm::is_contained(constantDims, 0))
7241 constantDims.assign(constantDims.size(), 0);
7242
7243 // Replace 'createMaskOp' with ConstantMaskOp.
7244 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7245 constantDims);
7246 return success();
7247 }
7248};
7249
7250} // namespace
7251
7252void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7253 MLIRContext *context) {
7254 results.add<CreateMaskFolder>(context);
7255}
7256
7257//===----------------------------------------------------------------------===//
7258// MaskOp
7259//===----------------------------------------------------------------------===//
7260
7261void MaskOp::build(
7262 OpBuilder &builder, OperationState &result, Value mask,
7263 Operation *maskableOp,
7264 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7265 assert(maskRegionBuilder &&
7266 "builder callback for 'maskRegion' must be present");
7267
7268 result.addOperands(mask);
7269 OpBuilder::InsertionGuard guard(builder);
7270 Region *maskRegion = result.addRegion();
7271 builder.createBlock(maskRegion);
7272 maskRegionBuilder(builder, maskableOp);
7273}
7274
7275void MaskOp::build(
7276 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7277 Value mask, Operation *maskableOp,
7278 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7279 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
7280 maskRegionBuilder);
7281}
7282
7283void MaskOp::build(
7284 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7285 Value mask, Value passthru, Operation *maskableOp,
7286 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7287 build(builder, result, mask, maskableOp, maskRegionBuilder);
7288 if (passthru)
7289 result.addOperands(passthru);
7290 result.addTypes(resultTypes);
7291}
7292
7293ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7294 // Create the op region.
7295 result.regions.reserve(1);
7296 Region &maskRegion = *result.addRegion();
7297
7298 auto &builder = parser.getBuilder();
7299
7300 // Parse all the operands.
7301 OpAsmParser::UnresolvedOperand mask;
7302 if (parser.parseOperand(mask))
7303 return failure();
7304
7305 // Optional passthru operand.
7306 OpAsmParser::UnresolvedOperand passthru;
7307 ParseResult parsePassthru = parser.parseOptionalComma();
7308 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7309 return failure();
7310
7311 // Parse op region.
7312 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7313 return failure();
7314
7315 MaskOp::ensureTerminator(maskRegion, builder, result.location);
7316
7317 // Parse the optional attribute list.
7318 if (parser.parseOptionalAttrDict(result.attributes))
7319 return failure();
7320
7321 // Parse all the types.
7322 Type maskType;
7323 if (parser.parseColonType(maskType))
7324 return failure();
7325
7326 SmallVector<Type> resultTypes;
7327 if (parser.parseOptionalArrowTypeList(resultTypes))
7328 return failure();
7329 result.types.append(resultTypes);
7330
7331 // Resolve operands.
7332 if (parser.resolveOperand(mask, maskType, result.operands))
7333 return failure();
7334
7335 if (parsePassthru.succeeded()) {
7336 if (resultTypes.empty())
7337 return parser.emitError(
7338 parser.getNameLoc(),
7339 "expects a result if passthru operand is provided");
7340
7341 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7342 return failure();
7343 }
7344
7345 return success();
7346}
7347
7348void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7349 p << " " << getMask();
7350 if (getPassthru())
7351 p << ", " << getPassthru();
7352
7353 // Print single masked operation and skip terminator.
7354 p << " { ";
7355 Block *singleBlock = &getMaskRegion().getBlocks().front();
7356 if (singleBlock && !singleBlock->getOperations().empty())
7357 p.printCustomOrGenericOp(&singleBlock->front());
7358 p << " }";
7359
7360 p.printOptionalAttrDict(getOperation()->getAttrs());
7361
7362 p << " : " << getMask().getType();
7363 if (getNumResults() > 0)
7364 p << " -> " << getResultTypes();
7365}
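
// For reference, the custom assembly handled by the parse/print methods above
// looks roughly like this (illustrative masked op, operands and types; the
// vector.yield terminator is implicit in the custom form):
//
//   %0 = vector.mask %mask, %passthru {
//          vector.transfer_read %t[%idx], %pad : tensor<?xf32>, vector<16xf32>
//        } : vector<16xi1> -> vector<16xf32>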
7366
7367void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7368 // 1. For an empty `vector.mask`, create a default terminator.
7369 if (region.empty() || region.front().empty()) {
7370 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7371 MaskOp>::ensureTerminator(region, builder, loc);
7372 return;
7373 }
7374
7375 // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7376 Block &block = region.front();
7377 if (isa<vector::YieldOp>(block.back()))
7378 return;
7379
7380 // 3. For a non-empty `vector.mask` without an explicit terminator:
7381
7382 // Create a default terminator if the number of masked operations is not
7383 // one. This case will trigger a verification failure.
7384 if (block.getOperations().size() != 1) {
7385 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7386 MaskOp>::ensureTerminator(region, builder, loc);
7387 return;
7388 }
7389
7390 // Create a terminator that yields the results from the masked operation.
7391 OpBuilder opBuilder(builder.getContext());
7392 Operation *maskedOp = &block.front();
7393 opBuilder.setInsertionPointToEnd(&block);
7394 vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7395}
7396
7397LogicalResult MaskOp::verify() {
7398 // Structural checks.
7399 Block &block = getMaskRegion().getBlocks().front();
7400 if (block.getOperations().empty())
7401 return emitOpError("expects a terminator within the mask region");
7402
7403 unsigned numMaskRegionOps = block.getOperations().size();
7404 if (numMaskRegionOps > 2)
7405 return emitOpError("expects only one operation to mask");
7406
7407 // Terminator checks.
7408 auto terminator = dyn_cast<vector::YieldOp>(block.back());
7409 if (!terminator)
7410 return emitOpError("expects a terminator within the mask region");
7411
7412 if (terminator->getNumOperands() != getNumResults())
7413 return emitOpError(
7414 "expects number of results to match mask region yielded values");
7415
7416 // Empty vector.mask. Nothing else to check.
7417 if (numMaskRegionOps == 1)
7418 return success();
7419
7420 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7421 if (!maskableOp)
7422 return emitOpError("expects a MaskableOpInterface within the mask region");
7423
7424 // Result checks.
7425 if (maskableOp->getNumResults() != getNumResults())
7426 return emitOpError("expects number of results to match maskable operation "
7427 "number of results");
7428
7429 if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7430 return emitOpError("expects all the results from the MaskableOpInterface "
7431 "to match all the values returned by the terminator");
7432
7433 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7434 return emitOpError(
7435 "expects result type to match maskable operation result type");
7436
7437 if (llvm::count_if(maskableOp->getResultTypes(),
7438 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7439 return emitOpError("multiple vector results not supported");
7440
7441 // Mask checks.
7442 Type expectedMaskType = maskableOp.getExpectedMaskType();
7443 if (getMask().getType() != expectedMaskType)
7444 return emitOpError("expects a ")
7445 << expectedMaskType << " mask for the maskable operation";
7446
7447 // Passthru checks.
7448 Value passthru = getPassthru();
7449 if (passthru) {
7450 if (!maskableOp.supportsPassthru())
7451 return emitOpError(
7452 "doesn't expect a passthru argument for this maskable operation");
7453
7454 if (maskableOp->getNumResults() != 1)
7455 return emitOpError("expects result when passthru argument is provided");
7456
7457 if (passthru.getType() != maskableOp->getResultTypes()[0])
7458 return emitOpError("expects passthru type to match result type");
7459 }
7460
7461 return success();
7462}
7463
7464/// Folds an empty `vector.mask` that has no passthru operand, with or without
7465/// return values. For example:
7466///
7467/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7468/// vector<8xi1> -> vector<8xf32>
7469/// %1 = user_op %0 : vector<8xf32>
7470///
7471/// becomes:
7472///
7473/// %0 = user_op %a : vector<8xf32>
7474///
7475/// An empty `vector.mask` with a passthru operand is handled by the
7476/// canonicalizer, as it requires creating new operations.
7477
7478static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7479 SmallVectorImpl<OpFoldResult> &results) {
7480 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7481 return failure();
7482
7483 Block *block = maskOp.getMaskBlock();
7484 auto terminator = cast<vector::YieldOp>(block->front());
7485 if (terminator.getNumOperands() == 0) {
7486 // `vector.mask` has no results, just remove the `vector.mask`.
7487 return success();
7488 }
7489
7490 // `vector.mask` has results, propagate the results.
7491 llvm::append_range(results, terminator.getOperands());
7492 return success();
7493}
7494
7495LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7496 SmallVectorImpl<OpFoldResult> &results) {
7497 if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7498 return success();
7499
7500 MaskFormat maskFormat = getMaskFormat(getMask());
7501 if (maskFormat != MaskFormat::AllTrue)
7502 return failure();
7503
7504 // Move maskable operation outside of the `vector.mask` region.
7505 Operation *maskableOp = getMaskableOp();
7506 maskableOp->dropAllUses();
7507 maskableOp->moveBefore(getOperation());
7508
7509 llvm::append_range(results, maskableOp->getResults());
7510 return success();
7511}
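
// Sketch of the all-true case handled by the fold above (illustrative masked
// op and types): with
//
//   %m = vector.constant_mask [8] : vector<8xi1>
//   %0 = vector.mask %m { vector.reduction <add>, %a : vector<8xf32> into f32 }
//          : vector<8xi1> -> f32
//
// the masked operation is hoisted out of the region and the mask is dropped,
// leaving
//
//   %0 = vector.reduction <add>, %a : vector<8xf32> into f32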
7512
7513/// Canonicalize empty `vector.mask` operations that can't be handled in
7514/// `MaskOp::fold` as they require creating new operations.
7515///
7516/// Example 1: Empty `vector.mask` with passthru operand.
7517///
7518/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7519/// vector<8xi1> -> vector<8xf32>
7520///
7521/// becomes:
7522///
7523/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7524///
7525class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7526 using Base::Base;
7527
7528 LogicalResult matchAndRewrite(MaskOp maskOp,
7529 PatternRewriter &rewriter) const override {
7530 if (!maskOp.isEmpty())
7531 return failure();
7532
7533 if (!maskOp.hasPassthru())
7534 return failure();
7535
7536 Block *block = maskOp.getMaskBlock();
7537 auto terminator = cast<vector::YieldOp>(block->front());
7538 assert(terminator.getNumOperands() == 1 &&
7539 "expected one result when passthru is provided");
7540
7541 rewriter.replaceOpWithNewOp<arith::SelectOp>(
7542 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7543 terminator.getOperand(0), maskOp.getPassthru());
7544
7545 return success();
7546 }
7547};
7548
7549void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7550 MLIRContext *context) {
7551 results.add<CanonializeEmptyMaskOp>(context);
7552}
7553
7554// MaskingOpInterface definitions.
7555
7556/// Returns the operation masked by this 'vector.mask'.
7557Operation *MaskOp::getMaskableOp() {
7558 Block *block = getMaskBlock();
7559 if (block->getOperations().size() < 2)
7560 return nullptr;
7561
7562 return &block->front();
7563}
7564
7565/// Returns true if 'vector.mask' has a passthru value.
7566bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7567
7568//===----------------------------------------------------------------------===//
7569// ScanOp
7570//===----------------------------------------------------------------------===//
7571
7572LogicalResult ScanOp::verify() {
7573 VectorType srcType = getSourceType();
7574 VectorType initialType = getInitialValueType();
7575 // Check reduction dimension < rank.
7576 int64_t srcRank = srcType.getRank();
7577 int64_t reductionDim = getReductionDim();
7578 if (reductionDim >= srcRank)
7579 return emitOpError("reduction dimension ")
7580 << reductionDim << " has to be less than " << srcRank;
7581
7582 // Check that rank(initial_value) = rank(src) - 1.
7583 int64_t initialValueRank = initialType.getRank();
7584 if (initialValueRank != srcRank - 1)
7585 return emitOpError("initial value rank ")
7586 << initialValueRank << " has to be equal to " << srcRank - 1;
7587
7588 // Check shapes of initial value and src.
7589 ArrayRef<int64_t> srcShape = srcType.getShape();
7590 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7591 SmallVector<int64_t> expectedShape;
7592 for (int i = 0; i < srcRank; i++) {
7593 if (i != reductionDim)
7594 expectedShape.push_back(srcShape[i]);
7595 }
7596 if (!llvm::equal(initialValueShapes, expectedShape)) {
7597 return emitOpError("incompatible input/initial value shapes");
7598 }
7599
7600 // Verify supported reduction kind.
7601 Type eltType = getDestType().getElementType();
7602 if (!isSupportedCombiningKind(getKind(), eltType))
7603 return emitOpError("unsupported reduction type ")
7604 << eltType << " for kind '" << stringifyCombiningKind(getKind())
7605 << "'";
7606
7607 return success();
7608}
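
// For reference, a vector.scan that satisfies the checks above looks like
// this (illustrative shapes and combining kind):
//
//   %res:2 = vector.scan <add>, %src, %init
//       {inclusive = true, reduction_dim = 1 : i64}
//       : vector<4x8x16xf32>, vector<4x16xf32>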
7609
7610void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
7611 RewritePatternSet &patterns, PatternBenefit benefit) {
7612 patterns
7613 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7614 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7615 StridedSliceConstantMaskFolder, TransposeFolder>(
7616 patterns.getContext(), benefit);
7617}
7618
7619Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7620 CombiningKind kind, Value v1, Value acc,
7621 arith::FastMathFlagsAttr fastmath,
7622 Value mask) {
7623 Type t1 = getElementTypeOrSelf(v1.getType());
7624 Type tAcc = getElementTypeOrSelf(acc.getType());
7625 Value result;
7626
7627 switch (kind) {
7628 case CombiningKind::ADD:
7629 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7630 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7631 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7632 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7633 else
7634 llvm_unreachable("invalid value types for ADD reduction");
7635 break;
7636 case CombiningKind::AND:
7637 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7638 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7639 break;
7640 case CombiningKind::MAXNUMF:
7641 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7642 "expected float values");
7643 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7644 break;
7645 case CombiningKind::MAXIMUMF:
7646 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7647 "expected float values");
7648 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7649 break;
7650 case CombiningKind::MINNUMF:
7651 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7652 "expected float values");
7653 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7654 break;
7655 case CombiningKind::MINIMUMF:
7656 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7657 "expected float values");
7658 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7659 break;
7660 case CombiningKind::MAXSI:
7661 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7662 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7663 break;
7664 case CombiningKind::MINSI:
7665 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7666 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7667 break;
7668 case CombiningKind::MAXUI:
7669 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7670 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7671 break;
7672 case CombiningKind::MINUI:
7673 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7674 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7675 break;
7676 case CombiningKind::MUL:
7677 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7678 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7679 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7680 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7681 else
7682 llvm_unreachable("invalid value types for MUL reduction");
7683 break;
7684 case CombiningKind::OR:
7685 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7686 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7687 break;
7688 case CombiningKind::XOR:
7689 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7690 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7691 break;
7692 };
7693
7694 assert(result && "unknown CombiningKind");
7695 return selectPassthru(b, mask, result, acc);
7696}
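
// Illustrative use of makeArithReduction (the builder, loc, lhs, acc and mask
// values are assumed to exist in the caller):
//
//   Value red = vector::makeArithReduction(builder, loc, CombiningKind::ADD,
//                                          /*v1=*/lhs, /*acc=*/acc,
//                                          /*fastmath=*/nullptr, /*mask=*/mask);
//
// When a mask is provided, the result is routed through selectPassthru below,
// so masked-off lanes keep the accumulator value.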
7697
7698//===----------------------------------------------------------------------===//
7699// StepOp
7700//===----------------------------------------------------------------------===//
7701
7702void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7703 SetIntRangeFn setResultRanges) {
7704 auto resultType = cast<VectorType>(getType());
7705 if (resultType.isScalable()) {
7706 return;
7707 }
7708 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7709 APInt zero(bitwidth, 0);
7710 APInt high(bitwidth, resultType.getDimSize(0) - 1);
7711 ConstantIntRanges result = {zero, high, zero, high};
7712 setResultRanges(getResult(), result);
7713}
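
// For example, `%s = vector.step : vector<4xindex>` materializes [0, 1, 2, 3],
// so the range inferred above for this fixed-length result is [0, 3].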
7714
7715namespace {
7716
7717/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
7718/// constant large enough such that the result is the same at all indices.
7719///
7720/// For example, rewrite the 'greater than' comparison below,
7721///
7722/// ```mlir
7723/// %cst = arith.constant dense<7> : vector<3xindex>
7724/// %stp = vector.step : vector<3xindex>
7725/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
7726/// ```
7727///
7728/// as,
7729///
7730/// ```mlir
7731/// %out = arith.constant dense<false> : vector<3xi1>
7732/// ```
7733///
7734/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
7735/// is false at ALL indices, we fold. If the constant were 1, then
7736/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
7737/// conservatively preferring the 'compact' vector.step representation.
7738///
7739/// Note: this folder only works for the case where the constant (`%cst` above)
7740/// is the second operand of the comparison. The arith.cmpi canonicalizer will
7741/// ensure that constants are always second (on the right).
7742struct StepCompareFolder : public OpRewritePattern<StepOp> {
7743 using Base::Base;
7744
7745 LogicalResult matchAndRewrite(StepOp stepOp,
7746 PatternRewriter &rewriter) const override {
7747 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
7748
7749 for (OpOperand &use : stepOp.getResult().getUses()) {
7750 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
7751 if (!cmpiOp)
7752 continue;
7753
7754 // arith.cmpi canonicalizer makes constants final operands.
7755 const unsigned stepOperandNumber = use.getOperandNumber();
7756 if (stepOperandNumber != 0)
7757 continue;
7758
7759 // Check that operand 1 is a constant.
7760 unsigned constOperandNumber = 1;
7761 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7762 std::optional<int64_t> maybeConstValue =
7763 getConstantIntValue(otherOperand);
7764 if (!maybeConstValue.has_value())
7765 continue;
7766
7767 int64_t constValue = maybeConstValue.value();
7768 arith::CmpIPredicate pred = cmpiOp.getPredicate();
7769
7770 auto maybeSplat = [&]() -> std::optional<bool> {
7771 // Handle ult (unsigned less than) and uge (unsigned greater equal).
7772 if ((pred == arith::CmpIPredicate::ult ||
7773 pred == arith::CmpIPredicate::uge) &&
7774 stepSize <= constValue)
7775 return pred == arith::CmpIPredicate::ult;
7776
7777 // Handle ule and ugt.
7778 if ((pred == arith::CmpIPredicate::ule ||
7779 pred == arith::CmpIPredicate::ugt) &&
7780 stepSize - 1 <= constValue) {
7781 return pred == arith::CmpIPredicate::ule;
7782 }
7783
7784 // Handle eq and ne.
7785 if ((pred == arith::CmpIPredicate::eq ||
7786 pred == arith::CmpIPredicate::ne) &&
7787 stepSize <= constValue)
7788 return pred == arith::CmpIPredicate::ne;
7789
7790 return std::nullopt;
7791 }();
7792
7793 if (!maybeSplat.has_value())
7794 continue;
7795
7796 rewriter.setInsertionPointAfter(cmpiOp);
7797
7798 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7799 if (!type)
7800 continue;
7801
7802 auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
7803 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7804 type, boolAttr);
7805
7806 rewriter.replaceOp(cmpiOp, splat);
7807 return success();
7808 }
7809
7810 return failure();
7811 }
7812};
7813} // namespace
7814
7815void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
7816 MLIRContext *context) {
7817 results.add<StepCompareFolder>(context);
7818}
7819
7820//===----------------------------------------------------------------------===//
7821// Vector Masking Utilities
7822//===----------------------------------------------------------------------===//
7823
7824/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7825/// as the masked operation.
7826void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7827 Operation *maskableOp) {
7828 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
7829 Block *insBlock = builder.getInsertionBlock();
7830 // Move the masked op from its current block into the mask region's block.
7831 insBlock->getOperations().splice(
7832 insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7833 YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
7834}
7835
7836/// Creates a vector.mask operation around a maskable operation. Returns the
7837/// vector.mask operation if a mask is provided. Otherwise, returns the
7838/// maskable operation itself.
7839Operation *mlir::vector::maskOperation(OpBuilder &builder,
7840 Operation *maskableOp, Value mask,
7841 Value passthru) {
7842 if (!mask)
7843 return maskableOp;
7844 if (passthru)
7845 return MaskOp::create(builder, maskableOp->getLoc(),
7846 maskableOp->getResultTypes(), mask, passthru,
7847 maskableOp, createMaskOpRegion);
7848 return MaskOp::create(builder, maskableOp->getLoc(),
7849 maskableOp->getResultTypes(), mask, maskableOp,
7850 createMaskOpRegion);
7851}
7852
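// Illustrative call (the builder, readOp and mask values are assumed to be
// provided by the caller; readOp must implement MaskableOpInterface):
//
//   Operation *masked = vector::maskOperation(builder, readOp, mask);
//
// If mask is null, the maskable operation is returned unchanged.
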
7853/// Creates a vector select operation that picks values from `newValue` or
7854/// `passthru` for each result vector lane based on `mask`. This utility is used
7855/// to propagate the pass-thru value of vector.mask or for cases where only
7856/// pass-thru value propagation is needed. VP intrinsics do not support
7857/// pass-thru values and every masked-out lane is set to poison. LLVM backends
7858/// are usually able to match op + select patterns and fold them into native
7859/// target instructions.
7860Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7861 Value newValue, Value passthru) {
7862 if (!mask)
7863 return newValue;
7864
7865 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
7866 mask, newValue, passthru);
7867}
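
// The emitted IR is a plain element-wise select, e.g. (illustrative types):
//
//   %r = arith.select %mask, %newValue, %passthru : vector<8xi1>, vector<8xf32>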
7868
7869//===----------------------------------------------------------------------===//
7870// TableGen'd op method definitions
7871//===----------------------------------------------------------------------===//
7872
7873#define GET_ATTRDEF_CLASSES
7874#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
7875
7876#define GET_OP_CLASSES
7877#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"