1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15
26#include "mlir/IR/AffineExpr.h"
27#include "mlir/IR/AffineMap.h"
28#include "mlir/IR/Builders.h"
32#include "mlir/IR/IRMapping.h"
36#include "mlir/IR/ValueRange.h"
39#include "mlir/Support/LLVM.h"
41#include "llvm/ADT/ArrayRef.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/StringSet.h"
45#include "llvm/ADT/TypeSwitch.h"
46#include "llvm/Support/Casting.h"
47
48#include <cassert>
49#include <cstdint>
50#include <numeric>
51
52#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
53// Pull in all enum type and utility function definitions.
54#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
55
56using namespace mlir;
57using namespace mlir::vector;
58
59/// Helper enum to classify mask value.
60enum class MaskFormat {
61 AllTrue = 0,
62 AllFalse = 1,
63 Unknown = 2
64};
65
66/// Helper method to classify a mask value. Currently, the method
67/// looks "under the hood" of a constant value with dense attributes
68/// and a constant mask operation (since the client may be called at
69/// various stages during progressive lowering).
70static MaskFormat getMaskFormat(Value mask) {
71 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
72 // Inspect constant dense values. We count up for bits that
73 // are set, count down for bits that are cleared, and bail
74 // when a mix is detected.
75 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
76 int64_t val = 0;
77 for (bool b : denseElts.getValues<bool>())
78 if (b && val >= 0)
79 val++;
80 else if (!b && val <= 0)
81 val--;
82 else
83 return MaskFormat::Unknown;
84 if (val > 0)
85 return MaskFormat::AllTrue;
86 if (val < 0)
87 return MaskFormat::AllFalse;
88 }
89 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
90 // Inspect constant mask index. If the index exceeds the
91 // dimension size, all bits are set. If the index is zero
92 // or less, no bits are set.
93 ArrayRef<int64_t> masks = m.getMaskDimSizes();
94 auto shape = m.getType().getShape();
95 bool allTrue = true;
96 bool allFalse = true;
97 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
98 if (maskIdx < dimSize)
99 allTrue = false;
100 if (maskIdx > 0)
101 allFalse = false;
102 }
103 if (allTrue)
104 return MaskFormat::AllTrue;
105 if (allFalse)
106 return MaskFormat::AllFalse;
107 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
108 // Finds all-false create_masks. An all-true create_mask requires all
109 // dims to be constants, so that'll be folded to a constant_mask, then
110 // detected in the constant_mask case.
111 auto maskOperands = m.getOperands();
112 for (Value operand : maskOperands) {
113 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
114 int64_t dimSize =
115 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
116 if (dimSize <= 0)
117 return MaskFormat::AllFalse;
118 }
119 }
120 return MaskFormat::Unknown;
121 }
122 return MaskFormat::Unknown;
123}
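// For example (an illustrative sketch; the SSA names are hypothetical):
//
//   %t = vector.constant_mask [4] : vector<4xi1>  // classified as AllTrue
//   %f = vector.constant_mask [0] : vector<4xi1>  // classified as AllFalse
//
// A mask whose bits cannot be determined statically is classified as Unknown.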
124
125/// Default callback to build a region with a 'vector.yield' terminator with no
126/// arguments.
127void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
128 vector::YieldOp::create(builder, loc);
129}
130
131// Helper for verifying combining kinds in contractions and reductions.
132static bool isSupportedCombiningKind(CombiningKind combiningKind,
133 Type elementType) {
134 switch (combiningKind) {
135 case CombiningKind::ADD:
136 case CombiningKind::MUL:
137 return elementType.isIntOrIndexOrFloat();
138 case CombiningKind::MINUI:
139 case CombiningKind::MINSI:
140 case CombiningKind::MAXUI:
141 case CombiningKind::MAXSI:
142 case CombiningKind::AND:
143 case CombiningKind::OR:
144 case CombiningKind::XOR:
145 return elementType.isIntOrIndex();
146 case CombiningKind::MINNUMF:
147 case CombiningKind::MAXNUMF:
148 case CombiningKind::MINIMUMF:
149 case CombiningKind::MAXIMUMF:
150 return llvm::isa<FloatType>(elementType);
151 }
152 return false;
153}
154
155/// Returns the effective rank of the vector to read/write for Xfer Ops
156///
157/// When the element type of the shaped type is _a scalar_, this will simply
158/// return the rank of the vector (the result for xfer_read or the value to
159/// store for xfer_write).
160///
161/// When the element type of the base shaped type is _a vector_, returns the
162/// difference between the original vector type and the element type of the
163/// shaped type.
164///
165/// EXAMPLE 1 (element type is _a scalar_):
166/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
167/// - shapedType.getElementType() = f32 (rank 0)
168/// - vectorType.getRank() = 2
169/// - Result = 2 - 0 = 2
170///
171/// EXAMPLE 2 (element type is _a vector_):
172/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
173/// - shapedType.getElementType() = vector<20xf32> (rank 1)
174/// - vectorType.getRank() = 1
175/// - Result = 1 - 1 = 0
176///
177/// This is used to determine the number of minor dimensions for identity maps
178/// in vector transfer Ops.
179static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
180 VectorType vectorType) {
181 unsigned elementVectorRank = 0;
182 VectorType elementVectorType =
183 llvm::dyn_cast<VectorType>(shapedType.getElementType());
184 if (elementVectorType)
185 elementVectorRank += elementVectorType.getRank();
186 return vectorType.getRank() - elementVectorRank;
187}
188
189AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
190 VectorType vectorType) {
191 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
192 // TODO: replace once we have 0-d vectors.
193 if (shapedType.getRank() == 0 &&
194 vectorType.getShape() == ArrayRef<int64_t>{1})
195 return AffineMap::get(
196 /*numDims=*/0, /*numSymbols=*/0,
197 getAffineConstantExpr(0, shapedType.getContext()));
198 return AffineMap::getMinorIdentityMap(
199 shapedType.getRank(),
200 getEffectiveVectorRankForXferOp(shapedType, vectorType),
201 shapedType.getContext());
202}
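// For example (illustrative): for memref<10x20x30xf32> and vector<4x5xf32>,
// the minor identity permutation map is (d0, d1, d2) -> (d1, d2).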
203
204/// Check if `write` is of a constant splat and the masked `read` is padded with
205/// the same splat value -- meaning it could be the same value as the initial
206/// constant splat.
207static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
208 vector::TransferReadOp read) {
209 auto readMask = read.getMask();
210 auto writeMask = write.getMask();
211 // Check if the masks are consistent. The splat value could be the same if the
212 // read is masked (and padded with the splat value), and the write is unmasked
213 // or has the same mask. Note this does not allow the case where the write is
214 // masked and the read is unmasked, as then the read could be of more elements
215 // than the write (which may not be the same value).
216 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
217 if (!couldBeSameSplat)
218 return false;
219 // Check for constant splat (as the source of the write).
220 DenseElementsAttr splatAttr;
221 if (!matchPattern(write.getVector(),
222 m_Constant<DenseElementsAttr>(&splatAttr)) ||
223 !splatAttr.isSplat()) {
224 return false;
225 }
226 // The padding of the read and the constant splat value must be the same.
227 Attribute padAttr;
228 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
229 return false;
230 return padAttr == splatAttr.getSplatValue<Attribute>();
231}
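// Illustrative pattern (pseudo-IR; types, indices and in_bounds are abridged,
// and the SSA names are hypothetical):
//
//   %splat = arith.constant dense<0.0> : vector<4xf32>
//   vector.transfer_write %splat, %t[%i], %mask ...
//   %r = vector.transfer_read %t[%i], %pad, %mask ...
//
// The masked read is consistent with the write when %pad equals the splat
// value 0.0, since masked-off lanes of the read are padded with that value.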
232
233bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
234 vector::TransferReadOp read) {
235 return !defWrite.hasOutOfBoundsDim() &&
236 defWrite.getIndices() == read.getIndices() &&
237 defWrite.getVectorType() == read.getVectorType() &&
238 defWrite.getPermutationMap() == read.getPermutationMap() &&
239 ((!defWrite.getMask() && !read.getMask()) ||
240 isSplatWriteConsistentWithMaskedRead(defWrite, read));
241}
242
243bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
244 vector::TransferWriteOp priorWrite) {
245 return priorWrite.getIndices() == write.getIndices() &&
246 priorWrite.getMask() == write.getMask() &&
247 priorWrite.getVectorType() == write.getVectorType() &&
248 priorWrite.getPermutationMap() == write.getPermutationMap();
249}
250
251bool mlir::vector::isDisjointTransferIndices(
252 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
253 bool testDynamicValueUsingBounds) {
254 // For simplicity only look at transfer of same type.
255 if (transferA.getVectorType() != transferB.getVectorType())
256 return false;
257 unsigned rankOffset = transferA.getLeadingShapedRank();
258 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
259 Value indexA = transferA.getIndices()[i];
260 Value indexB = transferB.getIndices()[i];
261 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
262 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
263
264 if (i < rankOffset) {
265 // For leading dimensions, if we can prove that the indices are different,
266 // we know we are accessing disjoint slices.
267 if (cstIndexA.has_value() && cstIndexB.has_value()) {
268 if (*cstIndexA != *cstIndexB)
269 return true;
270 continue;
271 }
272 if (testDynamicValueUsingBounds) {
273 // First try to see if we can fully compose and simplify the affine
274 // expression as a fast track.
275 FailureOr<uint64_t> delta =
276 affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
277 if (succeeded(delta) && *delta != 0)
278 return true;
279
280 FailureOr<bool> testEqual =
281 ValueBoundsConstraintSet::areEqual(indexA, indexB);
282 if (succeeded(testEqual) && !testEqual.value())
283 return true;
284 }
285 } else {
286 // For this dimension, we slice a part of the memref, so we need to make
287 // sure the intervals accessed don't overlap.
288 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
289 if (cstIndexA.has_value() && cstIndexB.has_value()) {
290 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
291 if (distance >= vectorDim)
292 return true;
293 continue;
294 }
295 if (testDynamicValueUsingBounds) {
296 // First try to see if we can fully compose and simplify the affine
297 // expression as a fast track.
298 FailureOr<int64_t> delta =
299 affine::fullyComposeAndComputeConstantDelta(indexA, indexB);
300 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
301 return true;
302
303 FailureOr<int64_t> computeDelta =
304 ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB);
305 if (succeeded(computeDelta)) {
306 if (std::abs(computeDelta.value()) >= vectorDim)
307 return true;
308 }
309 }
310 }
311 }
312 return false;
313}
314
315bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
316 VectorTransferOpInterface transferB,
317 bool testDynamicValueUsingBounds) {
318 if (transferA.getBase() != transferB.getBase())
319 return false;
320 return isDisjointTransferIndices(transferA, transferB,
321 testDynamicValueUsingBounds);
322}
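// For example (illustrative): two transfers of vector<4xf32> on the same
// memref at constant indices 0 and 4 access the intervals [0, 4) and [4, 8)
// and are therefore disjoint. With `testDynamicValueUsingBounds`, the same
// conclusion may be reached for dynamic indices via value-bounds analysis.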
323
324// Helper to iterate over n-D vector slice elements. Calculate the next
325// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
326// Modifies the `position` in place. Returns a failure when `position` becomes
327// the end position.
328static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
329 ArrayRef<int64_t> shape,
330 ArrayRef<int64_t> offsets) {
331 for (auto [posInDim, dimSize, offsetInDim] :
332 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
333 ++posInDim;
334 if (posInDim < dimSize + offsetInDim)
335 return success();
336
337 // Carry the overflow to the next loop iteration.
338 posInDim = offsetInDim;
339 }
340
341 return failure();
342}
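// Worked example (illustrative): with shape = [2, 3] and offsets = [0, 0],
// starting from position = [0, 0] successive calls produce
// [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], and the next call returns failure().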
343
344/// Returns the integer numbers in `values`. `values` are expected to be
345/// constant operations.
346SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
347 SmallVector<int64_t> ints;
348 llvm::transform(values, std::back_inserter(ints), [](Value value) {
349 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
350 assert(constOp && "Unexpected non-constant index");
351 return constOp.value();
352 });
353 return ints;
354}
355
356/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
357/// be constant operations.
358SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
359 SmallVector<int64_t> ints;
360 llvm::transform(
361 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
362 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
363 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
364 });
365 return ints;
366}
367
368/// Convert `foldResults` into Values. Integer attributes are converted to
369/// constant op.
370SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc,
371 ArrayRef<OpFoldResult> foldResults) {
372 SmallVector<Value> values;
373 llvm::transform(foldResults, std::back_inserter(values),
374 [&](OpFoldResult foldResult) {
375 if (auto attr = dyn_cast<Attribute>(foldResult))
376 return arith::ConstantIndexOp::create(
377 builder, loc, cast<IntegerAttr>(attr).getInt())
378 .getResult();
379
380 return cast<Value>(foldResult);
381 });
382 return values;
383}
384
385std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
386 if (value.getDefiningOp<vector::VectorScaleOp>())
387 return 1;
388 auto mul = value.getDefiningOp<arith::MulIOp>();
389 if (!mul)
390 return {};
391 auto lhs = mul.getLhs();
392 auto rhs = mul.getRhs();
393 if (lhs.getDefiningOp<vector::VectorScaleOp>())
394 return getConstantIntValue(rhs);
395 if (rhs.getDefiningOp<vector::VectorScaleOp>())
396 return getConstantIntValue(lhs);
397 return {};
398}
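// For example (illustrative):
//   %vs = vector.vscale
//   %c4 = arith.constant 4 : index
//   %n  = arith.muli %vs, %c4 : index
// this returns 4 when queried on %n, and 1 when queried on %vs itself.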
399
400/// Converts numeric attributes to the expected type. Supports
401/// integer-to-integer and float-to-integer conversions. Returns the original
402/// attribute if no conversion is needed or supported.
403static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
404 // Integer-to-integer conversion
405 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
406 if (auto intType = dyn_cast<IntegerType>(expectedType)) {
407 if (intAttr.getType() != expectedType)
408 return IntegerAttr::get(expectedType, intAttr.getInt());
409 }
410 return attr;
411 }
412
413 // Float-to-integer bitcast (preserves bit representation)
414 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
415 auto intType = dyn_cast<IntegerType>(expectedType);
416 if (!intType)
417 return attr;
418
419 APFloat floatVal = floatAttr.getValue();
420 APInt intVal = floatVal.bitcastToAPInt();
421 return IntegerAttr::get(expectedType, intVal);
422 }
423
424 return attr;
425}
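// For example (illustrative): converting a FloatAttr holding the f32 value 1.0
// to an expected i32 type yields an IntegerAttr with value 0x3F800000, the raw
// bit pattern of 1.0f.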
426
427//===----------------------------------------------------------------------===//
428// CombiningKindAttr
429//===----------------------------------------------------------------------===//
430
431namespace mlir {
432namespace vector {
433namespace detail {
434struct BitmaskEnumStorage : public AttributeStorage {
435 using KeyTy = uint64_t;
436
437 BitmaskEnumStorage(KeyTy val) : value(val) {}
438
439 bool operator==(const KeyTy &key) const { return value == key; }
440
441 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
442 const KeyTy &key) {
443 return new (allocator.allocate<BitmaskEnumStorage>())
444 BitmaskEnumStorage(key);
445 }
446
447 KeyTy value = 0;
448};
449} // namespace detail
450} // namespace vector
451} // namespace mlir
452
453//===----------------------------------------------------------------------===//
454// VectorDialect
455//===----------------------------------------------------------------------===//
456
457namespace {
458/// This class defines the interface for handling inlining with vector dialect
459/// operations.
460struct VectorInlinerInterface : public DialectInlinerInterface {
461 using DialectInlinerInterface::DialectInlinerInterface;
462
463 /// All vector dialect ops can be inlined.
464 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
465 return true;
466 }
467};
468} // namespace
469
470void VectorDialect::initialize() {
471 addAttributes<
472#define GET_ATTRDEF_LIST
473#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
474 >();
475
476 addOperations<
477#define GET_OP_LIST
478#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
479 >();
480
481 addInterfaces<VectorInlinerInterface>();
482
483 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
484 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
485 YieldOp>();
486 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
487 TransferWriteOp>();
488 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
489 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
490 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
491}
492
493/// Materialize a single constant operation from a given attribute value with
494/// the desired resultant type.
495Operation *VectorDialect::materializeConstant(OpBuilder &builder,
496 Attribute value, Type type,
497 Location loc) {
498 if (isa<ub::PoisonAttrInterface>(value))
499 return value.getDialect().materializeConstant(builder, value, type, loc);
500
501 return arith::ConstantOp::materialize(builder, value, type, loc);
502}
503
504IntegerType vector::getVectorSubscriptType(Builder &builder) {
505 return builder.getIntegerType(64);
506}
507
508ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
509 ArrayRef<int64_t> values) {
510 return builder.getI64ArrayAttr(values);
511}
512
513//===----------------------------------------------------------------------===//
514// MultiDimReductionOp
515//===----------------------------------------------------------------------===//
516
517void vector::MultiDimReductionOp::build(OpBuilder &builder,
518 OperationState &result, Value source,
519 Value acc, ArrayRef<bool> reductionMask,
520 CombiningKind kind) {
521 SmallVector<int64_t> reductionDims;
522 for (const auto &en : llvm::enumerate(reductionMask))
523 if (en.value())
524 reductionDims.push_back(en.index());
525 build(builder, result, kind, source, acc, reductionDims);
526}
527
528OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
529 // Single parallel dim, this is a noop.
530 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
531 return getSource();
532 return {};
533}
534
535std::optional<SmallVector<int64_t, 4>>
536MultiDimReductionOp::getShapeForUnroll() {
537 return llvm::to_vector<4>(getSourceVectorType().getShape());
538}
539
540LogicalResult MultiDimReductionOp::verify() {
541 SmallVector<int64_t> targetShape;
542 SmallVector<bool> scalableDims;
543 Type inferredReturnType;
544 auto sourceScalableDims = getSourceVectorType().getScalableDims();
545 for (auto [dimIdx, dimSize] :
546 llvm::enumerate(getSourceVectorType().getShape()))
547 if (!llvm::any_of(getReductionDims(),
548 [dimIdx = dimIdx](int64_t reductionDimIdx) {
549 return reductionDimIdx == static_cast<int64_t>(dimIdx);
550 })) {
551 targetShape.push_back(dimSize);
552 scalableDims.push_back(sourceScalableDims[dimIdx]);
553 }
554 // TODO: update to also allow 0-d vectors when available.
555 if (targetShape.empty())
556 inferredReturnType = getSourceVectorType().getElementType();
557 else
558 inferredReturnType = VectorType::get(
559 targetShape, getSourceVectorType().getElementType(), scalableDims);
560 if (getType() != inferredReturnType)
561 return emitOpError() << "destination type " << getType()
562 << " is incompatible with source type "
563 << getSourceVectorType();
564
565 return success();
566}
567
568/// Returns the mask type expected by this operation.
569Type MultiDimReductionOp::getExpectedMaskType() {
570 auto vecType = getSourceVectorType();
571 return VectorType::get(vecType.getShape(),
572 IntegerType::get(vecType.getContext(), /*width=*/1),
573 vecType.getScalableDims());
574}
575
576namespace {
577// Only unit dimensions that are being reduced are folded. If the dimension is
578// unit, but not reduced, it is not folded, thereby keeping the output type the
579// same. If not all dimensions which are reduced are of unit dimension, this
580// transformation does nothing. This is just a generalization of
581// ElideSingleElementReduction for ReduceOp.
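// For example (illustrative):
//   %0 = vector.multi_reduction <add>, %src, %acc [1]
//          : vector<4x1xf32> to vector<4xf32>
// reduces only a unit dimension, so it can be rewritten as a shape_cast of
// %src combined with %acc.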
582struct ElideUnitDimsInMultiDimReduction
583 : public OpRewritePattern<MultiDimReductionOp> {
584 using Base::Base;
585
586 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
587 PatternRewriter &rewriter) const override {
588 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
589 for (const auto &dim : enumerate(shape)) {
590 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
591 return failure();
592 }
593
594 // Vector mask setup.
595 OpBuilder::InsertionGuard guard(rewriter);
596 Operation *rootOp;
597 Value mask;
598 if (reductionOp.isMasked()) {
599 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
600 rootOp = reductionOp.getMaskingOp();
601 mask = reductionOp.getMaskingOp().getMask();
602 } else {
603 rootOp = reductionOp;
604 }
605
606 Location loc = reductionOp.getLoc();
607 Value acc = reductionOp.getAcc();
608 Value cast;
609 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
610 if (mask) {
611 VectorType newMaskType =
612 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
613 dstVecType.getScalableDims());
614 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
615 }
616 cast = vector::ShapeCastOp::create(
617 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
618 } else {
619 // This means we are reducing all the dimensions, and all reduction
620 // dimensions are of size 1. So a simple extraction would do.
621 if (mask)
622 mask = vector::ExtractOp::create(rewriter, loc, mask);
623 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
624 }
625
626 Value result =
627 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
628 cast, /*fastmath=*/nullptr, mask);
629 rewriter.replaceOp(rootOp, result);
630 return success();
631 }
632};
633} // namespace
634
635void MultiDimReductionOp::getCanonicalizationPatterns(
636 RewritePatternSet &results, MLIRContext *context) {
637 results.add<ElideUnitDimsInMultiDimReduction>(context);
638}
639
640//===----------------------------------------------------------------------===//
641// ReductionOp
642//===----------------------------------------------------------------------===//
643
644void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
645 CombiningKind kind, Value vector,
646 arith::FastMathFlags fastMathFlags) {
647 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
648}
649
650void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
651 CombiningKind kind, Value vector, Value acc,
652 arith::FastMathFlags fastMathFlags) {
653 build(builder, result,
654 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
655 acc, fastMathFlags);
656}
657
658LogicalResult ReductionOp::verify() {
659 // Verify for 0-D and 1-D vector.
660 int64_t rank = getSourceVectorType().getRank();
661 if (rank > 1)
662 return emitOpError("unsupported reduction rank: ") << rank;
663
664 // Verify supported reduction kind.
665 Type eltType = getDest().getType();
666 if (!isSupportedCombiningKind(getKind(), eltType))
667 return emitOpError("unsupported reduction type '")
668 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
669 << "'";
670
671 return success();
672}
673
674// MaskableOpInterface methods.
675
676/// Returns the mask type expected by this operation.
677Type ReductionOp::getExpectedMaskType() {
678 auto vecType = getSourceVectorType();
679 return VectorType::get(vecType.getShape(),
680 IntegerType::get(vecType.getContext(), /*width=*/1),
681 vecType.getScalableDims());
682}
683
684Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
685 OpBuilder &builder, Location loc,
686 Value vector) {
687 switch (op) {
688 case arith::AtomicRMWKind::addf:
689 case arith::AtomicRMWKind::addi:
690 return vector::ReductionOp::create(builder, vector.getLoc(),
691 CombiningKind::ADD, vector);
692 case arith::AtomicRMWKind::mulf:
693 case arith::AtomicRMWKind::muli:
694 return vector::ReductionOp::create(builder, vector.getLoc(),
695 CombiningKind::MUL, vector);
696 case arith::AtomicRMWKind::minimumf:
697 return vector::ReductionOp::create(builder, vector.getLoc(),
698 CombiningKind::MINIMUMF, vector);
699 case arith::AtomicRMWKind::mins:
700 return vector::ReductionOp::create(builder, vector.getLoc(),
701 CombiningKind::MINSI, vector);
702 case arith::AtomicRMWKind::minu:
703 return vector::ReductionOp::create(builder, vector.getLoc(),
704 CombiningKind::MINUI, vector);
705 case arith::AtomicRMWKind::maximumf:
706 return vector::ReductionOp::create(builder, vector.getLoc(),
707 CombiningKind::MAXIMUMF, vector);
708 case arith::AtomicRMWKind::maxs:
709 return vector::ReductionOp::create(builder, vector.getLoc(),
710 CombiningKind::MAXSI, vector);
711 case arith::AtomicRMWKind::maxu:
712 return vector::ReductionOp::create(builder, vector.getLoc(),
713 CombiningKind::MAXUI, vector);
714 case arith::AtomicRMWKind::andi:
715 return vector::ReductionOp::create(builder, vector.getLoc(),
716 CombiningKind::AND, vector);
717 case arith::AtomicRMWKind::ori:
718 return vector::ReductionOp::create(builder, vector.getLoc(),
719 CombiningKind::OR, vector);
720 case arith::AtomicRMWKind::minnumf:
721 return vector::ReductionOp::create(builder, vector.getLoc(),
722 CombiningKind::MINNUMF, vector);
723 case arith::AtomicRMWKind::maxnumf:
724 return vector::ReductionOp::create(builder, vector.getLoc(),
725 CombiningKind::MAXNUMF, vector);
726 case arith::AtomicRMWKind::xori:
727 return vector::ReductionOp::create(builder, vector.getLoc(),
728 CombiningKind::XOR, vector);
729 default:
730 (void)emitOptionalError(loc, "Reduction operation type not supported");
731 break;
732 }
733 return nullptr;
734}
735
736std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
737 return llvm::to_vector<4>(getSourceVectorType().getShape());
738}
739
740namespace {
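// Elide reductions over single-element (or 0-D) vectors by extracting the
// element directly. For example (illustrative):
//   %0 = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
// can be rewritten as an extract of the single element combined with %acc.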
741struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
742 using Base::Base;
743
744 LogicalResult matchAndRewrite(ReductionOp reductionOp,
745 PatternRewriter &rewriter) const override {
746 // Vector mask setup.
747 OpBuilder::InsertionGuard guard(rewriter);
748 auto maskableOp =
749 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
750 Operation *rootOp;
751 Value mask;
752 if (maskableOp.isMasked()) {
753 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
754 rootOp = maskableOp.getMaskingOp();
755 mask = maskableOp.getMaskingOp().getMask();
756 } else {
757 rootOp = reductionOp;
758 }
759
760 auto vectorType = reductionOp.getSourceVectorType();
761 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
762 return failure();
763
764 Location loc = reductionOp.getLoc();
765 if (mask)
766 mask = ExtractOp::create(rewriter, loc, mask);
767 Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
768
769 if (Value acc = reductionOp.getAcc())
770 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
771 result, acc,
772 reductionOp.getFastmathAttr(), mask);
773
774 rewriter.replaceOp(rootOp, result);
775 return success();
776 }
777};
778} // namespace
779
780void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
781 MLIRContext *context) {
782 results.add<ElideSingleElementReduction>(context);
783}
784
785//===----------------------------------------------------------------------===//
786// ContractionOp
787//===----------------------------------------------------------------------===//
788
789void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
790 Value lhs, Value rhs, Value acc,
791 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
792 ArrayRef<IteratorType> iteratorTypes) {
793 result.addOperands({lhs, rhs, acc});
794 result.addTypes(acc.getType());
795 result.addAttribute(
796 getIndexingMapsAttrName(result.name),
797 builder.getAffineMapArrayAttr(
798 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
799 result.addAttribute(
800 getIteratorTypesAttrName(result.name),
801 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
802 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
803 return IteratorTypeAttr::get(builder.getContext(), t);
804 }))));
805}
806
807void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
808 Value lhs, Value rhs, Value acc,
809 ArrayAttr indexingMaps,
810 ArrayAttr iteratorTypes) {
811 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
812 ContractionOp::getDefaultKind());
813}
814
815void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
816 Value lhs, Value rhs, Value acc,
817 ArrayAttr indexingMaps,
818 ArrayAttr iteratorTypes, CombiningKind kind) {
819 result.addOperands({lhs, rhs, acc});
820 result.addTypes(acc.getType());
821 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
822 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
823 result.addAttribute(getKindAttrName(result.name),
824 CombiningKindAttr::get(builder.getContext(), kind));
825}
826
827ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
828 OpAsmParser::UnresolvedOperand lhsInfo;
829 OpAsmParser::UnresolvedOperand rhsInfo;
830 OpAsmParser::UnresolvedOperand accInfo;
831 SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
832 SmallVector<Type, 2> types;
833 Type resultType;
834 auto loc = parser.getCurrentLocation();
835 DictionaryAttr dictAttr;
836 // TODO: Unify linalg op attribute parsing.
837 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
838 parser.parseComma() || parser.parseOperand(rhsInfo) ||
839 parser.parseComma() || parser.parseOperand(accInfo) ||
840 parser.parseTrailingOperandList(masksInfo) ||
841 parser.parseOptionalAttrDict(result.attributes) ||
842 parser.parseColonTypeList(types) ||
843 parser.parseKeywordType("into", resultType) ||
844 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
845 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
846 parser.resolveOperand(accInfo, resultType, result.operands) ||
847 parser.addTypeToList(resultType, result.types))
848 return failure();
849 result.attributes.append(dictAttr.getValue().begin(),
850 dictAttr.getValue().end());
851
852 // Convert the array of strings into an array of IteratorType enums. This is
853 // needed because tests still use the old format where the 'iterator_types'
854 // attribute is represented as an array of strings.
855 // TODO: Remove this conversion once tests are fixed.
856 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
857 result.attributes.get(getIteratorTypesAttrName(result.name)));
858 if (!iteratorTypes) {
859 return parser.emitError(loc)
860 << "expected " << getIteratorTypesAttrName(result.name)
861 << " array attribute";
862 }
863
864 SmallVector<Attribute> iteratorTypeAttrs;
865
866 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
867 auto maybeIteratorType = symbolizeIteratorType(s);
868 if (!maybeIteratorType.has_value())
869 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
870
871 iteratorTypeAttrs.push_back(
872 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
873 }
874 result.attributes.set(getIteratorTypesAttrName(result.name),
875 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
876
877 if (!result.attributes.get(getKindAttrName(result.name))) {
878 result.addAttribute(
879 getKindAttrName(result.name),
880 CombiningKindAttr::get(result.getContext(),
881 ContractionOp::getDefaultKind()));
882 }
883 if (masksInfo.empty())
884 return success();
885 if (masksInfo.size() != 2)
886 return parser.emitError(parser.getNameLoc(),
887 "expected zero or exactly 2 vector mask operands");
888 auto lhsType = llvm::cast<VectorType>(types[0]);
889 auto rhsType = llvm::cast<VectorType>(types[1]);
890 auto maskElementType = parser.getBuilder().getI1Type();
891 std::array<VectorType, 2> maskTypes = {
892 VectorType::Builder(lhsType).setElementType(maskElementType),
893 VectorType::Builder(rhsType).setElementType(maskElementType)};
894 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
895 return failure();
896 return success();
897}
898
899void ContractionOp::print(OpAsmPrinter &p) {
900 // TODO: Unify printing code with linalg ops.
901 auto attrNames = getTraitAttrNames();
902 llvm::StringSet<> traitAttrsSet;
903 traitAttrsSet.insert_range(attrNames);
904 SmallVector<NamedAttribute> attrs;
905 for (auto attr : (*this)->getAttrs()) {
906 if (attr.getName() == getIteratorTypesAttrName()) {
907 auto iteratorTypes =
908 llvm::cast<ArrayAttr>(attr.getValue())
909 .getAsValueRange<IteratorTypeAttr, IteratorType>();
910 // Convert IteratorType enums into the string representation. This is
911 // needed, because tests still use the old format when 'iterator_types'
912 // attribute is represented as an array of strings.
913 // TODO: Remove this conversion once tests are fixed.
914 SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
915 llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
916 return StringAttr::get(getContext(), stringifyIteratorType(t));
917 }));
918
919 attrs.emplace_back(getIteratorTypesAttrName(),
920 ArrayAttr::get(getContext(), iteratorTypeNames));
921 } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
922 attrs.push_back(attr);
923 }
924
925 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
926 p << " " << dictAttr << " " << getLhs() << ", ";
927 p << getRhs() << ", " << getAcc();
928
929 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
930 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
931 << getResultType();
932}
933
934static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
935 const std::vector<std::pair<int64_t, int64_t>> &map) {
936 for (auto &dimPair : map) {
937 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
938 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
939 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
940 return false;
941 }
942 return true;
943}
944
945static LogicalResult verifyOutputShape(
946 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
947 Type resType,
948 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
949 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
950 DenseSet<int64_t> lhsContractingDimSet;
951 DenseSet<int64_t> rhsContractingDimSet;
952 for (auto &dimPair : contractingDimMap) {
953 lhsContractingDimSet.insert(dimPair.first);
954 rhsContractingDimSet.insert(dimPair.second);
955 }
956 DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
957 llvm::make_second_range(batchDimMap));
958
959 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
960 SmallVector<int64_t, 4> expectedResultDims;
961 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
962 if (lhsContractingDimSet.count(i) > 0)
963 continue;
964 expectedResultDims.push_back(lhsType.getDimSize(i));
965 }
966
967 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
968 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
969 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
970 continue;
971 expectedResultDims.push_back(rhsType.getDimSize(i));
972 }
973
974 // Verify 'expectedResultDims'.
975 if (expectedResultDims.empty()) {
976 // No batch or free dimension implies a scalar result.
977 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
978 return op.emitOpError("invalid accumulator/result vector shape");
979 } else {
980 // At least one batch or free dimension implies a vector result.
981 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
982 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
983 if (!resVectorType || !accVectorType)
984 return op.emitOpError("invalid accumulator/result vector shape");
985
986 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
987 // types fully define the result vector type. This assumes the affine maps
988 // are well-formed, which must have been verified already.
989 MLIRContext *ctx = op.getContext();
990 AffineMap lhsMap = op.getIndexingMapsArray()[0];
991 AffineMap rhsMap = op.getIndexingMapsArray()[1];
992 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
993 return op.emitOpError(
994 "expected all dimensions to be either a LHS or a RHS dimension");
995 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
996 for (auto pair :
997 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
998 VectorType v = pair.first;
999 auto map = pair.second;
1000 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1001 unsigned pos = map.getDimPosition(idx);
1002 if (!extents[pos])
1003 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
1004 }
1005 }
1006 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
1007 return op.emitOpError("expected all dimensions to get an extent as "
1008 "either a LHS or a RHS dimension");
1009
1010 AffineMap resMap = op.getIndexingMapsArray()[2];
1011 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
1012 /*symbolCount=*/0, extents, ctx);
1013 // Compose the resMap with the extentsMap, which is a constant map.
1014 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
1015 assert(llvm::all_of(expectedMap.getResults(),
1016 llvm::IsaPred<AffineConstantExpr>) &&
1017 "expected constant extent along all dimensions.");
1018 // Extract the expected shape and build the type.
1019 auto expectedShape = llvm::to_vector<4>(
1020 llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
1021 return cast<AffineConstantExpr>(e).getValue();
1022 }));
1023 auto expected =
1024 VectorType::get(expectedShape, resVectorType.getElementType(),
1025 resVectorType.getScalableDims());
1026 if (resVectorType != expected || accVectorType != expected)
1027 return op.emitOpError(
1028 "invalid accumulator/result vector shape, expected: ")
1029 << expected;
1030 }
1031 return success();
1032}
1033
1034LogicalResult ContractionOp::verify() {
1035 VectorType lhsType = getLhsType();
1036 VectorType rhsType = getRhsType();
1037 Type accType = getAccType();
1038 Type resType = getResultType();
1039
1040 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1041 if (!lhsType.getElementType().isSignlessInteger())
1042 return emitOpError("only supports signless integer types");
1043 }
1044
1045 // Verify that an indexing map was specified for each vector operand.
1046 if (getIndexingMapsArray().size() != 3)
1047 return emitOpError("expected an indexing map for each vector operand");
1048
1049 // Verify that each index map has 'numIterators' inputs, no symbols, and
1050 // that the number of map outputs equals the rank of its associated
1051 // vector operand.
1052 unsigned numIterators = getIteratorTypes().getValue().size();
1053 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1054 auto index = it.index();
1055 auto map = it.value();
1056 if (map.getNumSymbols() != 0)
1057 return emitOpError("expected indexing map ")
1058 << index << " to have no symbols";
1059 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1060 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1061 // Verify that the map has the right number of inputs, outputs, and indices.
1062 // This also correctly accounts for (..) -> () for rank-0 results.
1063 if (map.getNumDims() != numIterators)
1064 return emitOpError("expected indexing map ")
1065 << index << " to have " << numIterators << " number of inputs";
1066 if (map.getNumResults() != rank)
1067 return emitOpError("expected indexing map ")
1068 << index << " to have " << rank << " number of outputs";
1069 if (!map.isProjectedPermutation())
1070 return emitOpError("expected indexing map ")
1071 << index << " to be a projected permutation of its inputs";
1072 }
1073
1074 auto contractingDimMap = getContractingDimMap();
1075 auto batchDimMap = getBatchDimMap();
1076
1077 // Verify at least one contracting dimension pair was specified.
1078 if (contractingDimMap.empty())
1079 return emitOpError("expected at least one contracting dimension pair");
1080
1081 // Verify contracting dimension map was properly constructed.
1082 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1083 return emitOpError("invalid contracting dimension map");
1084
1085 // Verify batch dimension map was properly constructed.
1086 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1087 return emitOpError("invalid batch dimension map");
1088
1089 // Verify 'accType' and 'resType' shape.
1090 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1091 contractingDimMap, batchDimMap)))
1092 return failure();
1093
1094 // Verify supported combining kind.
1095 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1096 auto elementType = vectorType ? vectorType.getElementType() : resType;
1097 if (!isSupportedCombiningKind(getKind(), elementType))
1098 return emitOpError("unsupported contraction type");
1099
1100 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1101 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1102}
1103
1104// MaskableOpInterface methods.
1105
1106/// Returns the mask type expected by this operation. Mostly used for
1107/// verification purposes. It requires the operation to be vectorized.
1108Type ContractionOp::getExpectedMaskType() {
1109 auto indexingMaps = this->getIndexingMapsArray();
1110 AffineMap lhsIdxMap = indexingMaps[0];
1111 AffineMap rhsIdxMap = indexingMaps[1];
1112 VectorType lhsType = this->getLhsType();
1113 VectorType rhsType = this->getRhsType();
1114
1115 unsigned numVecDims = lhsIdxMap.getNumDims();
1116 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1117 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1118
1119 // Using the information in the indexing maps, extract the size of each
1120 // dimension in the vector.contract operation from the two input operands.
1121 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1122 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1123 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1124 lhsType.getScalableDims()[dimIdx];
1125 }
1126 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1127 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1128 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1129 rhsType.getScalableDims()[dimIdx];
1130 }
1131
1132 assert(ShapedType::isStaticShape(maskShape) &&
1133 "Mask shape couldn't be computed");
1134
1135 return VectorType::get(maskShape,
1136 IntegerType::get(lhsType.getContext(), /*width=*/1),
1137 maskShapeScalableDims);
1138}
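// For example (illustrative): for a matmul-style contraction with
// lhs : vector<4x8xf32> mapped by (m, n, k) -> (m, k) and
// rhs : vector<8x16xf32> mapped by (m, n, k) -> (k, n), the expected mask type
// is vector<4x16x8xi1>, one bit per (m, n, k) iteration point.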
1139
1140SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1141 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1142 getIteratorTypesAttrName(), getKindAttrName()};
1143}
1144
1145static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
1146 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1147 if (targetExpr == map.getResult(i))
1148 return i;
1149 return -1;
1150}
1151
1152static std::vector<std::pair<int64_t, int64_t>>
1153getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1154 IteratorType targetIteratorType, MLIRContext *context) {
1155 std::vector<std::pair<int64_t, int64_t>> dimMap;
1156 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1157 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1158 if (iteratorType != targetIteratorType)
1159 continue;
1160 // Search lhs/rhs map results for 'targetExpr'.
1161 auto targetExpr = getAffineDimExpr(it.index(), context);
1162 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1163 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1164 if (lhsDim >= 0 && rhsDim >= 0)
1165 dimMap.emplace_back(lhsDim, rhsDim);
1166 }
1167 return dimMap;
1168}
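// For example (illustrative): with matmul-style maps
//   lhs: (m, n, k) -> (m, k),  rhs: (m, n, k) -> (k, n)
// and iterator_types = [parallel, parallel, reduction], querying the reduction
// iterator yields the single contracting dim pair (lhsDim = 1, rhsDim = 0).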
1169
1170void ContractionOp::getIterationBounds(
1171 SmallVectorImpl<int64_t> &iterationBounds) {
1172 auto lhsShape = getLhsType().getShape();
1173 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1174 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1175 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1176 // Search lhs/rhs map results for 'targetExpr'.
1177 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1178 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1179 if (iteratorType == IteratorType::reduction) {
1180 // Get reduction dim size from lhs shape (same size in rhsShape).
1181 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1182 assert(lhsDimIndex >= 0);
1183 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1184 continue;
1185 }
1186 // Get parallel dimension size from result shape.
1187 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1188 assert(resDimIndex >= 0);
1189 assert(resVectorType != nullptr);
1190 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1191 }
1192}
1193
1194void ContractionOp::getIterationIndexMap(
1195 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1196 unsigned numMaps = getIndexingMapsArray().size();
1197 iterationIndexMap.resize(numMaps);
1198 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1199 auto index = it.index();
1200 auto map = it.value();
1201 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1202 auto dim = cast<AffineDimExpr>(map.getResult(i));
1203 iterationIndexMap[index][dim.getPosition()] = i;
1204 }
1205 }
1206}
1207
1208std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1209 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1210 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1211 getContext());
1212}
1213
1214std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1215 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1216 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1217 getContext());
1218}
1219
1220std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1221 SmallVector<int64_t, 4> shape;
1222 getIterationBounds(shape);
1223 return shape;
1224}
1225
1226/// Return a fused vector::ContractionOp which represents a pattern such as:
1227///
1228/// ```mlir
1229/// %c0 = arith.constant 0: ...
1230/// %c = vector.contract %a, %b, %c0: ...
1231/// %e = add %c, %d: ...
1232/// ```
1233///
1234/// by:
1235///
1236/// ```mlir
1237/// %e = vector.contract %a, %b, %d: ...
1238/// ```
1239///
1240/// Return null if the canonicalization does not apply.
1241// TODO: This should be a folding of Add into Contract in core but while they
1242// live in different dialects, it is not possible without unnatural
1243// dependencies.
1244template <typename AddOpType>
1245struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1246 using OpRewritePattern<AddOpType>::OpRewritePattern;
1247
1248 LogicalResult matchAndRewrite(AddOpType addOp,
1249 PatternRewriter &rewriter) const override {
1250 auto canonicalize = [&](Value maybeContraction,
1251 Value otherOperand) -> vector::ContractionOp {
1252 vector::ContractionOp contractionOp =
1253 dyn_cast_or_null<vector::ContractionOp>(
1254 maybeContraction.getDefiningOp());
1255 if (!contractionOp)
1256 return vector::ContractionOp();
1257 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1258 contractionOp.getAcc().getDefiningOp())) {
1259 if (maybeZero.getValue() ==
1260 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1261 IRMapping bvm;
1262 bvm.map(contractionOp.getAcc(), otherOperand);
1263 auto newContraction =
1264 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1265 rewriter.replaceOp(addOp, newContraction.getResult());
1266 return newContraction;
1267 }
1268 }
1269 return vector::ContractionOp();
1270 };
1271
1272 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1273 vector::ContractionOp contract = canonicalize(a, b);
1274 contract = contract ? contract : canonicalize(b, a);
1275 return contract ? success() : failure();
1276 }
1277};
1278
1279void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1280 MLIRContext *context) {
1281 results.add<CanonicalizeContractAdd<arith::AddIOp>,
1282 CanonicalizeContractAdd<arith::AddFOp>>(context);
1283}
1284
1285// Returns `true` if `index` is either within [0, maxIndex) or equal to
1286// `poisonValue`.
1287static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
1288 int64_t maxIndex) {
1289 return index == poisonValue || (index >= 0 && index < maxIndex);
1290}
1291
1292//===----------------------------------------------------------------------===//
1293// ExtractOp
1294//===----------------------------------------------------------------------===//
1295
1296void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1297 SetIntRangeFn setResultRanges) {
1298 setResultRanges(getResult(), argRanges.front());
1299}
1300
1301void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1302 Value source) {
1303 auto vectorTy = cast<VectorType>(source.getType());
1304 build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1305}
1306
1307void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1308 Value source, int64_t position) {
1309 build(builder, result, source, ArrayRef<int64_t>{position});
1310}
1311
1312void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1313 Value source, OpFoldResult position) {
1314 build(builder, result, source, ArrayRef<OpFoldResult>{position});
1315}
1316
1317void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1318 Value source, ArrayRef<int64_t> position) {
1319 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1320 builder.getDenseI64ArrayAttr(position));
1321}
1322
1323void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1324 Value source, ArrayRef<OpFoldResult> position) {
1325 SmallVector<int64_t> staticPos;
1326 SmallVector<Value> dynamicPos;
1327 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1328 build(builder, result, source, dynamicPos,
1329 builder.getDenseI64ArrayAttr(staticPos));
1330}
1331
1332LogicalResult
1333ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1334 ExtractOp::Adaptor adaptor,
1335 SmallVectorImpl<Type> &inferredReturnTypes) {
1336 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1337 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1338 vectorType.getRank()) {
1339 inferredReturnTypes.push_back(vectorType.getElementType());
1340 } else {
1341 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1342 vectorType.getRank());
1343 inferredReturnTypes.push_back(VectorType::get(
1344 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1345 vectorType.getScalableDims().drop_front(n)));
1346 }
1347 return success();
1348}
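// For example (illustrative): extracting at position [1] from
// vector<4x8x16xf32> infers the result type vector<8x16xf32>, while extracting
// at [1, 2, 3] infers the scalar element type f32.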
1349
1350bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1351 // Allow extracting 1-element vectors instead of scalars.
1352 auto isCompatible = [](TypeRange l, TypeRange r) {
1353 auto vectorType = llvm::dyn_cast<VectorType>(l.front());
1354 return vectorType && vectorType.getShape().equals({1}) &&
1355 vectorType.getElementType() == r.front();
1356 };
1357 if (l.size() == 1 && r.size() == 1 &&
1358 (isCompatible(l, r) || isCompatible(r, l)))
1359 return true;
1360 return l == r;
1361}
1362
1363LogicalResult vector::ExtractOp::verify() {
1364 if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1365 if (resTy.getRank() == 0)
1366 return emitError(
1367 "expected a scalar instead of a 0-d vector as the result type");
1368
1369 // Note: This check must come before getMixedPosition() to prevent a crash.
1370 auto dynamicMarkersCount =
1371 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1372 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1373 return emitOpError(
1374 "mismatch between dynamic and static positions (kDynamic marker but no "
1375 "corresponding dynamic position) -- this can only happen due to an "
1376 "incorrect fold/rewrite");
1377 auto position = getMixedPosition();
1378 if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1379 return emitOpError(
1380 "expected position attribute of rank no greater than vector rank");
1381 for (auto [idx, pos] : llvm::enumerate(position)) {
1382 if (auto attr = dyn_cast<Attribute>(pos)) {
1383 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1384 if (!isValidPositiveIndexOrPoison(
1385 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1386 return emitOpError("expected position attribute #")
1387 << (idx + 1)
1388 << " to be a non-negative integer smaller than the "
1389 "corresponding vector dimension or poison (-1)";
1390 }
1391 }
1392 }
1393 return success();
1394}
1395
1396template <typename IntType>
1397static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1398 return llvm::to_vector<4>(llvm::map_range(
1399 arrayAttr.getAsRange<IntegerAttr>(),
1400 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1401}
1402
1403/// Fold the result of chains of ExtractOp in place by simply concatenating the
1404/// positions.
1405static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1406 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1407 return failure();
1408
1409 // TODO: Canonicalization for dynamic position not implemented yet.
1410 if (extractOp.hasDynamicPosition())
1411 return failure();
1412
1413 SmallVector<int64_t> globalPosition;
1414 ExtractOp currentOp = extractOp;
1415 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1416 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1417 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1418 currentOp = nextOp;
1419 // TODO: Canonicalization for dynamic position not implemented yet.
1420 if (currentOp.hasDynamicPosition())
1421 return failure();
1422 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1423 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1424 }
1425 extractOp.setOperand(0, currentOp.getSource());
1426 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1427 OpBuilder b(extractOp.getContext());
1428 std::reverse(globalPosition.begin(), globalPosition.end());
1429 extractOp.setStaticPosition(globalPosition);
1430 return success();
1431}
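// For example (illustrative):
//   %0 = vector.extract %arg[1] : vector<3x4xf32> from vector<2x3x4xf32>
//   %1 = vector.extract %0[2] : vector<4xf32> from vector<3x4xf32>
// folds the second op in place to:
//   %1 = vector.extract %arg[1, 2] : vector<4xf32> from vector<2x3x4xf32>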
1432
1433namespace {
1434/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1435/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1436/// Compose TransposeOp permutations as we walk back.
1437/// This helper class keeps an updated extraction position `extractPosition`
1438/// with extra trailing sentinels.
1439/// The sentinels encode the internal transposition status of the result vector.
1440/// As we iterate, extractPosition is permuted and updated.
1441class ExtractFromInsertTransposeChainState {
1442public:
1443 ExtractFromInsertTransposeChainState(ExtractOp e);
1444
1445 /// Iterate over producing insert and transpose ops until we find a fold.
1446 Value fold();
1447
1448private:
1449 /// Return true if the vector at position `a` is contained within the vector
1450 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1451 /// is a prefix of `b`.
1452 template <typename ContainerA, typename ContainerB>
1453 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1454 return a.size() <= b.size() &&
1455 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1456 }
1457
1458 /// Return true if the vector at position `a` intersects the vector at
1459 /// position `b`. Under insert/extract semantics, this is the same as equality
1460 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1461 /// Comparison is on the common prefix (i.e. zip).
1462 template <typename ContainerA, typename ContainerB>
1463 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1464 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1465 if (elemA < 0 || elemB < 0)
1466 continue;
1467 if (elemA != elemB)
1468 return false;
1469 }
1470 return true;
1471 }
1472
1473 /// Folding is only possible in the absence of an internal permutation in the
1474 /// result vector.
1475 bool canFold() {
1476 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1477 }
1478
1479 // Helper to get the next defining op of interest.
1480 void updateStateForNextIteration(Value v) {
1481 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1482 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1483 };
1484
1485 // Case 1. If we hit a transpose, just compose the map and iterate.
1486 // Invariant: insert + transpose do not change rank, we can always compose.
1487 LogicalResult handleTransposeOp();
1488
1489 // Case 2: the insert position matches extractPosition exactly, early return.
1490 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1491
1492 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1493 /// portion of the source of the insert.
1494 /// Example:
1495 /// ```
1496 /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1497 /// // extractPosition == [1, 2, 3]
1498 /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1499 /// // can fold to vector.extract %source[0, 3]
1500 /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1501 /// ```
1502 /// To traverse through %source, we need to set the leading dims to 0 and
1503 /// drop the extra leading dims.
1504 /// This method updates the internal state.
1505 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1506
1507 /// Try to fold in place to extract(source, extractPosition) and return the
1508 /// folded result. Return null if folding is not possible (e.g. due to an
1509 /// internal transposition in the result).
1510 Value tryToFoldExtractOpInPlace(Value source);
1511
1512 ExtractOp extractOp;
1513 int64_t vectorRank;
1514 int64_t extractedRank;
1515
1516 InsertOp nextInsertOp;
1517 TransposeOp nextTransposeOp;
1518
1519 /// Sentinel values that encode the internal permutation status of the result.
1520 /// They are set to (-1, ... , -k) at the beginning and appended to
1521 /// `extractPosition`.
1522 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1523 /// ensure that there is no internal transposition.
1524 /// Internal transposition cannot be accounted for with a folding pattern.
1525 // TODO: We could relax the internal transposition with an extra transposition
1526 // operation in a future canonicalizer.
1527 SmallVector<int64_t> sentinels;
1528 SmallVector<int64_t> extractPosition;
1529};
1530} // namespace
1531
1532ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1533 ExtractOp e)
1534 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1535 extractedRank(extractOp.getNumIndices()) {
1536 assert(vectorRank >= extractedRank && "Extracted position overflow");
1537 sentinels.reserve(vectorRank - extractedRank);
1538 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1539 sentinels.push_back(-(i + 1));
1540 extractPosition.assign(extractOp.getStaticPosition().begin(),
1541 extractOp.getStaticPosition().end());
1542 llvm::append_range(extractPosition, sentinels);
1543}
1544
1545// Case 1. If we hit a transpose, just compose the map and iterate.
1546// Invariant: insert + transpose do not change rank, we can always compose.
1547LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1548 // TODO: Canonicalization for dynamic position not implemented yet.
1549 if (extractOp.hasDynamicPosition())
1550 return failure();
1551
1552 if (!nextTransposeOp)
1553 return failure();
1554 AffineMap m = inversePermutation(AffineMap::getPermutationMap(
1555 nextTransposeOp.getPermutation(), extractOp.getContext()));
1556 extractPosition = applyPermutationMap(m, ArrayRef(extractPosition));
1557 return success();
1558}
1559
1560// Case 2: the insert position matches extractPosition exactly, early return.
1561LogicalResult
1562ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1563 Value &res) {
1564 // TODO: Canonicalization for dynamic position not implemented yet.
1565 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1566 return failure();
1567
1568 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1569 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1570 return failure();
1571 // Case 2.a. early-exit fold.
1572 res = nextInsertOp.getValueToStore();
1573 // Case 2.b. if internal transposition is present, canFold will be false.
1574 return success(canFold());
1575}
1576
1577/// Case 3: if inserted position is a prefix of extractPosition,
1578/// extract a portion of the source of the insertion.
1579/// This method updates the internal state.
1580LogicalResult
1581ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1582 // TODO: Canonicalization for dynamic position not implemented yet.
1583 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1584 return failure();
1585
1586 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1587 if (!isContainedWithin(insertedPos, extractPosition))
1588 return failure();
1589 // Set leading dims to zero.
1590 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1591 // Drop extra leading dims.
1592 extractPosition.erase(extractPosition.begin(),
1593 extractPosition.begin() + insertedPos.size());
1594 extractedRank = extractPosition.size() - sentinels.size();
1595 // Case 3.a. early-exit fold (break and delegate to post-while path).
1596 res = nextInsertOp.getValueToStore();
1597 // Case 3.b. if internal transposition is present, canFold will be false.
1598 return success();
1599}
1600
1601/// Try to fold in place to extract(source, extractPosition) and return the
1602/// folded result. Return null if folding is not possible (e.g. due to an
1603/// internal transposition in the result).
1604Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1605 Value source) {
1606 // TODO: Canonicalization for dynamic position not implemented yet.
1607 if (extractOp.hasDynamicPosition())
1608 return Value();
1609
1610 // If we can't fold (either internal transposition, or nothing to fold), bail.
1611 bool nothingToFold = (source == extractOp.getSource());
1612 if (nothingToFold || !canFold())
1613 return Value();
1614
1615 // Otherwise, fold by updating the op inplace and return its result.
1616 OpBuilder b(extractOp.getContext());
1617 extractOp.setStaticPosition(
1618 ArrayRef(extractPosition).take_front(extractedRank));
1619 extractOp.getSourceMutable().assign(source);
1620 return extractOp.getResult();
1621}
1622
1623/// Iterate over producing insert and transpose ops until we find a fold.
1624Value ExtractFromInsertTransposeChainState::fold() {
1625 // TODO: Canonicalization for dynamic position not implemented yet.
1626 if (extractOp.hasDynamicPosition())
1627 return Value();
1628
1629 Value valueToExtractFrom = extractOp.getSource();
1630 updateStateForNextIteration(valueToExtractFrom);
1631 while (nextInsertOp || nextTransposeOp) {
1632 // Case 1. If we hit a transpose, just compose the map and iterate.
1633 // Invariant: insert + transpose do not change rank, we can always compose.
1634 if (succeeded(handleTransposeOp())) {
1635 valueToExtractFrom = nextTransposeOp.getVector();
1636 updateStateForNextIteration(valueToExtractFrom);
1637 continue;
1638 }
1639
1640 Value result;
1641 // Case 2: the positions match exactly.
1642 if (succeeded(handleInsertOpWithMatchingPos(result)))
1643 return result;
1644
1645 // Case 3: if the inserted position is a prefix of extractPosition, we can
1646 // just extract a portion of the source of the insert.
1647 if (succeeded(handleInsertOpWithPrefixPos(result)))
1648 return tryToFoldExtractOpInPlace(result);
1649
1650 // Case 4: extractPosition intersects insertedPos on non-sentinel values.
1651 // This is a more difficult case and we bail.
1652 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1653 if (isContainedWithin(extractPosition, insertedPos) ||
1654 intersectsWhereNonNegative(extractPosition, insertedPos))
1655 return Value();
1656
1657 // Case 5: No intersection, we forward the extract to insertOp.dest().
1658 valueToExtractFrom = nextInsertOp.getDest();
1659 updateStateForNextIteration(valueToExtractFrom);
1660 }
1661 // If after all this we can fold, go for it.
1662 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1663}
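// For illustration only (SSA names and shapes are hypothetical): given
//
//   %i = vector.insert %s, %dst[7] : vector<4xf32> into vector<8x4xf32>
//   %e = vector.extract %i[2] : vector<4xf32> from vector<8x4xf32>
//
// the extract position [2] is disjoint from the insert position [7], so the
// walk forwards past the insert (Case 5) and the extract folds in place to
//
//   %e = vector.extract %dst[2] : vector<4xf32> from vector<8x4xf32>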
1664
1665/// Returns true if the operation has a 0-D vector type operand or result.
1666static bool hasZeroDimVectors(Operation *op) {
1667 auto hasZeroDimVectorType = [](Type type) -> bool {
1668 auto vecType = dyn_cast<VectorType>(type);
1669 return vecType && vecType.getRank() == 0;
1670 };
1671
1672 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1673 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1674}
1675
1676/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1677/// considered to be 'broadcastlike'.
1678static bool isBroadcastLike(Operation *op) {
1679 if (isa<BroadcastOp>(op))
1680 return true;
1681
1682 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1683 if (!shapeCast)
1684 return false;
1685
1686 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1687 // Checking that the destination shape has a prefix of 1s is not sufficient,
1688 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1689 // is that the source shape is a suffix of the destination shape.
1690 VectorType srcType = shapeCast.getSourceVectorType();
1691 ArrayRef<int64_t> srcShape = srcType.getShape();
1692 uint64_t srcRank = srcType.getRank();
1693 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1694 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1695}
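// For illustration (hypothetical SSA names), the following shape_cast is
// broadcast-like because its source shape (2, 3) is a suffix of its result
// shape:
//
//   %a = vector.shape_cast %v : vector<2x3xf32> to vector<1x1x2x3xf32>
//
// whereas a shape_cast of the same source to vector<1x3x2xf32>, as in the
// comment above, is not.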
1696
1697/// Fold extract(broadcast(X)) to either extract(X) or just X.
1698///
1699/// Example:
1700///
1701/// broadcast extract [1][2]
1702/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1703///
1704/// becomes
1705/// extract [1]
1706/// (3,4) -------------------------------------> (4)
1707///
1708///
1709/// The variable names used in this implementation correspond to the above
1710/// shapes as,
1711///
1712/// - (3, 4) is `input` shape.
1713/// - (2, 3, 4) is `broadcast` shape.
1714/// - (4) is `extract` shape.
1715///
1716/// This folding is possible when the suffix of `input` shape is the same as
1717/// `extract` shape.
1718static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1719
1720 Operation *defOp = extractOp.getSource().getDefiningOp();
1721 if (!defOp || !isBroadcastLike(defOp))
1722 return Value();
1723
1724 Value input = defOp->getOperand(0);
1725
1726 // Replace extract(broadcast(X)) with X
1727 if (extractOp.getType() == input.getType())
1728 return input;
1729
1730 // Get required types and ranks in the chain
1731 // input -> broadcast -> extract
1732 // (scalars are treated as rank-0).
1733 auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1734 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1735 unsigned inputRank = inputType ? inputType.getRank() : 0;
1736 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1737 unsigned extractRank = extractType ? extractType.getRank() : 0;
1738
1739 // Cannot fold away the broadcast if the extract result rank exceeds the input rank.
1740 if (extractRank > inputRank)
1741 return Value();
1742
1743 // The above condition guarantees that input is a vector.
1744 assert(inputType && "input must be a vector type because of previous checks");
1745 ArrayRef<int64_t> inputShape = inputType.getShape();
1746
1747 // In the case where there is a broadcast dimension in the suffix, it is not
1748 // possible to replace extract(broadcast(X)) with extract(X). Example:
1749 //
1750 // broadcast extract
1751 // (1) --------> (3,4) ------> (4)
1752 if (extractType &&
1753 extractType.getShape() != inputShape.take_back(extractRank))
1754 return Value();
1755
1756 // Replace extract(broadcast(X)) with extract(X).
1757 // First, determine the new extraction position.
1758 unsigned deltaOverall = inputRank - extractRank;
1759 unsigned deltaBroadcast = broadcastRank - inputRank;
1760 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1761 SmallVector<OpFoldResult> newPositions(deltaOverall);
1762 IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1763 for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1764 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1765 }
1766 auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1767 extractOp->setOperands(
1768 llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1769 extractOp.setStaticPosition(staticPos);
1770 return extractOp.getResult();
1771}
1772
1773/// Fold extractOp coming from ShuffleOp.
1774///
1775/// Example:
1776///
1777/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1778/// : vector<8xf32>, vector<8xf32>
1779/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1780/// ->
1781/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1782///
1783static Value foldExtractFromShuffle(ExtractOp extractOp) {
1784 // Dynamic positions are not folded as the resulting code would be more
1785 // complex than the input code.
1786 if (extractOp.hasDynamicPosition())
1787 return Value();
1788
1789 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1790 if (!shuffleOp)
1791 return Value();
1792
1793 // TODO: 0-D or multi-dimensional vectors not supported yet.
1794 if (shuffleOp.getResultVectorType().getRank() != 1)
1795 return Value();
1796
1797 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1798 auto shuffleMask = shuffleOp.getMask();
1799 int64_t extractIdx = extractOp.getStaticPosition()[0];
1800 int64_t shuffleIdx = shuffleMask[extractIdx];
1801
1802 // Find the shuffled vector to extract from based on the shuffle index.
1803 if (shuffleIdx < inputVecSize) {
1804 extractOp.setOperand(0, shuffleOp.getV1());
1805 extractOp.setStaticPosition({shuffleIdx});
1806 } else {
1807 extractOp.setOperand(0, shuffleOp.getV2());
1808 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1809 }
1810
1811 return extractOp.getResult();
1812}
1813
1814// Fold extractOp with source coming from ShapeCast op.
1815static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1816 // TODO: Canonicalization for dynamic position not implemented yet.
1817 if (extractOp.hasDynamicPosition())
1818 return Value();
1819
1820 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1821 if (!shapeCastOp)
1822 return Value();
1823
1824 // Get the nth dimension size starting from lowest dimension.
1825 auto getDimReverse = [](VectorType type, int64_t n) {
1826 return type.getShape().take_back(n + 1).front();
1827 };
1828 int64_t destinationRank =
1829 llvm::isa<VectorType>(extractOp.getType())
1830 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1831 : 0;
1832 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1833 return Value();
1834 if (destinationRank > 0) {
1835 auto destinationType =
1836 llvm::cast<VectorType>(extractOp.getResult().getType());
1837 for (int64_t i = 0; i < destinationRank; i++) {
1838 // The lowest dimension of the destination must match the lowest
1839 // dimension of the shapecast op source.
1840 // TODO: This case could be supported in a canonicalization pattern.
1841 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1842 getDimReverse(destinationType, i))
1843 return Value();
1844 }
1845 }
1846 // Extract the strides associated with the extract op vector source. Then use
1847 // this to calculate a linearized position for the extract.
1848 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1849 std::reverse(extractedPos.begin(), extractedPos.end());
1850 SmallVector<int64_t, 4> strides;
1851 int64_t stride = 1;
1852 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1853 strides.push_back(stride);
1854 stride *=
1855 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1856 }
1857
1858 int64_t position = linearize(extractedPos, strides);
1859 // Then extract the strides associated to the shapeCast op vector source and
1860 // delinearize the position using those strides.
1861 SmallVector<int64_t, 4> newStrides;
1862 int64_t numDimension =
1863 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1864 stride = 1;
1865 for (int64_t i = 0; i < numDimension; i++) {
1866 newStrides.push_back(stride);
1867 stride *=
1868 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1869 }
1870 std::reverse(newStrides.begin(), newStrides.end());
1871 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1872 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1873 OpBuilder b(extractOp.getContext());
1874 extractOp.setStaticPosition(newPosition);
1875 extractOp.setOperand(0, shapeCastOp.getSource());
1876 return extractOp.getResult();
1877}
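// For illustration (hypothetical values): the extract below can be redirected
// to the shape_cast source by linearizing the position in the result shape
// and delinearizing it in the source shape:
//
//   %sc = vector.shape_cast %src : vector<2x6xf32> to vector<4x3xf32>
//   %e  = vector.extract %sc[2, 1] : f32 from vector<4x3xf32>
//
// Position [2, 1] linearizes to 2 * 3 + 1 = 7, which delinearizes in the
// source shape (2, 6) to [1, 1], so the extract becomes
//
//   %e = vector.extract %src[1, 1] : f32 from vector<2x6xf32>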
1878
1879/// Fold an ExtractOp from ExtractStridedSliceOp.
1880static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1881 // TODO: Canonicalization for dynamic position not implemented yet.
1882 if (extractOp.hasDynamicPosition())
1883 return Value();
1884
1885 auto extractStridedSliceOp =
1886 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1887 if (!extractStridedSliceOp)
1888 return Value();
1889
1890 // 0-D vectors not supported.
1891 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1892 if (hasZeroDimVectors(extractStridedSliceOp))
1893 return Value();
1894
1895 // Return if 'extractStridedSliceOp' has non-unit strides.
1896 if (extractStridedSliceOp.hasNonUnitStrides())
1897 return Value();
1898
1899 // Trim offsets for dimensions fully extracted.
1900 auto sliceOffsets =
1901 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1902 while (!sliceOffsets.empty()) {
1903 size_t lastOffset = sliceOffsets.size() - 1;
1904 if (sliceOffsets.back() != 0 ||
1905 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1906 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1907 break;
1908 sliceOffsets.pop_back();
1909 }
1910 unsigned destinationRank = 0;
1911 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1912 destinationRank = vecType.getRank();
1913 // The dimensions of the result need to be untouched by the
1914 // extractStridedSlice op.
1915 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1916 sliceOffsets.size())
1917 return Value();
1918
1919 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1920 assert(extractedPos.size() >= sliceOffsets.size());
1921 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1922 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1923 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1924
1925 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1926 OpBuilder b(extractOp.getContext());
1927 extractOp.setStaticPosition(extractedPos);
1928 return extractOp.getResult();
1929}
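// For illustration (hypothetical values): the extract position is shifted by
// the slice offsets so it can read directly from the strided-slice source:
//
//   %s = vector.extract_strided_slice %src
//          {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
//          : vector<4x8xf32> to vector<2x4xf32>
//   %e = vector.extract %s[0, 3] : f32 from vector<2x4xf32>
//
// becomes
//
//   %e = vector.extract %src[1, 5] : f32 from vector<4x8xf32>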
1930
1931/// Fold extract_op fed from a chain of insertStridedSlice ops.
1932static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1933 // TODO: Canonicalization for dynamic position not implemented yet.
1934 if (extractOp.hasDynamicPosition())
1935 return Value();
1936
1937 int64_t destinationRank =
1938 llvm::isa<VectorType>(extractOp.getType())
1939 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1940 : 0;
1941 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1942 if (!insertOp)
1943 return Value();
1944
1945 // 0-D vectors not supported.
1946 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1947 if (hasZeroDimVectors(insertOp))
1948 return Value();
1949
1950 while (insertOp) {
1951 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1952 insertOp.getSourceVectorType().getRank();
1953 if (destinationRank > insertOp.getSourceVectorType().getRank())
1954 return Value();
1955 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1956 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1957
1958 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1959 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1960 }))
1961 return Value();
1962 bool disjoint = false;
1963 SmallVector<int64_t, 4> offsetDiffs;
1964 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1965 int64_t start = insertOffsets[dim];
1966 int64_t size =
1967 (dim < insertRankDiff)
1968 ? 1
1969 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1970 int64_t end = start + size;
1971 int64_t offset = extractOffsets[dim];
1972 // Check if the start of the extract offset is in the interval inserted.
1973 if (start <= offset && offset < end) {
1974 if (dim >= insertRankDiff)
1975 offsetDiffs.push_back(offset - start);
1976 continue;
1977 }
1978 disjoint = true;
1979 break;
1980 }
1981 // The extracted chunk overlaps with the inserted vector.
1982 if (!disjoint) {
1983 // If any of the inner dimensions are only partially inserted we have a
1984 // partial overlap.
1985 int64_t srcRankDiff =
1986 insertOp.getSourceVectorType().getRank() - destinationRank;
1987 for (int64_t i = 0; i < destinationRank; i++) {
1988 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1989 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1990 insertRankDiff))
1991 return Value();
1992 }
1993 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
1994 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1995 OpBuilder b(extractOp.getContext());
1996 extractOp.setStaticPosition(offsetDiffs);
1997 return extractOp.getResult();
1998 }
1999 // If the chunk extracted is disjoint from the chunk inserted, keep
2000 // looking in the insert chain.
2001 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2002 }
2003 return Value();
2004}
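// For illustration (hypothetical values): when the extract position falls
// entirely inside the chunk written by an insert_strided_slice, the extract
// can read from the inserted value instead:
//
//   %i = vector.insert_strided_slice %small, %big
//          {offsets = [2, 0], strides = [1, 1]}
//          : vector<2x8xf32> into vector<4x8xf32>
//   %e = vector.extract %i[3, 5] : f32 from vector<4x8xf32>
//
// becomes
//
//   %e = vector.extract %small[1, 5] : f32 from vector<2x8xf32>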
2005
2006/// Try to fold the extraction of a scalar from a vector defined by
2007/// vector.from_elements. E.g.:
2008///
2009/// %0 = vector.from_elements %a, %b : vector<2xf32>
2010/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2011/// ==> fold to %a
2012static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2013 // Dynamic extractions cannot be folded.
2014 if (extractOp.hasDynamicPosition())
2015 return {};
2016
2017 // Look for extract(from_elements).
2018 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2019 if (!fromElementsOp)
2020 return {};
2021
2022 // Scalable vectors are not supported.
2023 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2024 if (vecType.isScalable())
2025 return {};
2026
2027 // Only extractions of scalars are supported.
2028 int64_t rank = vecType.getRank();
2029 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2030 if (extractOp.getType() != vecType.getElementType())
2031 return {};
2032 assert(static_cast<int64_t>(indices.size()) == rank &&
2033 "unexpected number of indices");
2034
2035 // Compute flattened/linearized index and fold to operand.
2036 int flatIndex = 0;
2037 int stride = 1;
2038 for (int i = rank - 1; i >= 0; --i) {
2039 flatIndex += indices[i] * stride;
2040 stride *= vecType.getDimSize(i);
2041 }
2042 return fromElementsOp.getElements()[flatIndex];
2043}
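// For illustration, a hypothetical multi-dimensional case: in
//
//   %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
//   %1 = vector.extract %0[1, 0] : f32 from vector<2x2xf32>
//
// position [1, 0] linearizes (row-major) to 1 * 2 + 0 = 2, so %1 folds to %c.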
2044
2045/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2046/// then fold it.
2047template <typename OpType, typename AdaptorType>
2048static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2049 SmallVectorImpl<Value> &operands) {
2050 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2051 OperandRange dynamicPosition = op.getDynamicPosition();
2052 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2053 ArrayRef<int64_t> vectorShape;
2054 if constexpr (std::is_same_v<OpType, ExtractOp>)
2055 vectorShape = op.getSourceVectorType().getShape();
2056 else
2057 vectorShape = op.getDestVectorType().getShape();
2058
2059 // If there are no dynamic position operands, there is nothing to fold.
2060 if (!dynamicPosition.size())
2061 return {};
2062
2063 // `index` is used to iterate over the `dynamicPosition`.
2064 unsigned index = 0;
2065
2066 // `opChange` is a flag. If it is true, it means to update `op` in place.
2067 bool opChange = false;
2068 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2069 if (ShapedType::isStatic(staticPosition[i]))
2070 continue;
2071 Attribute positionAttr = dynamicPositionAttr[index];
2072 Value position = dynamicPosition[index++];
2073 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2074 int64_t value = attr.getInt();
2075 // Do not fold if the value is out of bounds (-1 signifies a poison
2076 // value rather than OOB index).
2077 if (value >= -1 && value < vectorShape[i]) {
2078 staticPosition[i] = attr.getInt();
2079 opChange = true;
2080 continue;
2081 }
2082 }
2083 operands.push_back(position);
2084 }
2085
2086 if (opChange) {
2087 op.setStaticPosition(staticPosition);
2088 op.getOperation()->setOperands(operands);
2089 // Return the original result to indicate an in-place folding happened.
2090 return op.getResult();
2091 }
2092 return {};
2093}
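// For illustration (hypothetical values): a dynamic index whose operand is an
// arith.constant is folded into the static position, e.g.
//
//   %c1 = arith.constant 1 : index
//   %e  = vector.extract %v[%c1, %i] : f32 from vector<4x4xf32>
//
// is updated in place to
//
//   %e = vector.extract %v[1, %i] : f32 from vector<4x4xf32>
//
// while the genuinely dynamic index %i is kept as an operand.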
2094
2095/// Fold an insert or extract operation into a poison value when a poison index
2096/// is found at any dimension of the static position.
2097static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context,
2098 ArrayRef<int64_t> staticPos,
2099 int64_t poisonVal) {
2100 if (!is_contained(staticPos, poisonVal))
2101 return {};
2102
2103 return ub::PoisonAttr::get(context);
2104}
2105
2106/// Fold a vector extract whose source is a poison value.
2107static Attribute foldPoisonSrcExtractOp(Attribute srcAttr) {
2108 if (isa_and_nonnull<ub::PoisonAttr>(srcAttr))
2109 return srcAttr;
2110
2111 return {};
2112}
2113
2114/// Fold a vector extract extracting from a DenseElementsAttr.
2115static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp,
2116 Attribute srcAttr) {
2117 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2118 if (!denseAttr) {
2119 return {};
2120 }
2121
2122 if (denseAttr.isSplat()) {
2123 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2124 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2125 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2126 return newAttr;
2127 }
2128
2129 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2130 if (vecTy.isScalable())
2131 return {};
2132
2133 if (extractOp.hasDynamicPosition()) {
2134 return {};
2135 }
2136
2137 // Materializing subsets of a large constant array can generally lead to
2138 // explosion in IR size because of different combinations of subsets that
2139 // can exist. However, vector.extract is a restricted form of subset
2140 // extract where you can only extract non-overlapping (or the same) subset for
2141 // a given rank of the subset. Because of this property, the IR size can only
2142 // increase at most by `rank * size(array)` from a single constant array being
2143 // extracted by multiple extracts.
2144
2145 // Calculate the linearized position of the contiguous chunk of elements to
2146 // extract.
2147 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2148 copy(extractOp.getStaticPosition(), completePositions.begin());
2149 int64_t startPos =
2150 linearize(completePositions, computeStrides(vecTy.getShape()));
2151 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2152
2153 TypedAttr newAttr;
2154 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2155 SmallVector<Attribute> elementValues(
2156 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2157 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2158 } else {
2159 newAttr = *denseValuesBegin;
2160 }
2161
2162 return newAttr;
2163}
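// For illustration (hypothetical constant): extracting from a dense constant
// materializes the selected sub-array, e.g.
//
//   %cst = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
//   %e   = vector.extract %cst[1] : vector<2xi32> from vector<2x2xi32>
//
// folds to the attribute dense<[2, 3]> : vector<2xi32>.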
2164
2165OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2166 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2167 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2168 // mismatch).
2169 if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2170 return getSource();
2171 if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2172 return res;
2173 // Fold `arith.constant` indices into the `vector.extract` operation.
2174 // Do not stop here as this fold may enable subsequent folds that require
2175 // constant indices.
2176 SmallVector<Value> operands = {getSource()};
2177 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2178
2179 if (auto res = foldPoisonIndexInsertExtractOp(
2180 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2181 return res;
2182 if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2183 return res;
2184 if (succeeded(foldExtractOpFromExtractChain(*this)))
2185 return getResult();
2186 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2187 return res;
2188 if (auto res = foldExtractFromBroadcast(*this))
2189 return res;
2190 if (auto res = foldExtractFromShuffle(*this))
2191 return res;
2192 if (auto res = foldExtractFromShapeCast(*this))
2193 return res;
2194 if (auto val = foldExtractFromExtractStrided(*this))
2195 return val;
2196 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2197 return val;
2198 if (auto val = foldScalarExtractFromFromElements(*this))
2199 return val;
2200
2201 return inplaceFolded;
2202}
2203
2204namespace {
2205
2206// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
2207class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2208public:
2209 using Base::Base;
2210
2211 LogicalResult matchAndRewrite(ExtractOp extractOp,
2212 PatternRewriter &rewriter) const override {
2213
2214 Operation *defOp = extractOp.getSource().getDefiningOp();
2215 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2216 if (!defOp || !isBroadcastLike(defOp) || !outType)
2217 return failure();
2218
2219 Value source = defOp->getOperand(0);
2220 if (isBroadcastableTo(source.getType(), outType) !=
2221 BroadcastableToResult::Success)
2222 return failure();
2223
2224 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2225 return success();
2226 }
2227};
2228
2229// Pattern to rewrite an ExtractOp(CreateMask) -> CreateMask.
2230class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2231public:
2232 using Base::Base;
2233
2234 LogicalResult matchAndRewrite(ExtractOp extractOp,
2235 PatternRewriter &rewriter) const override {
2236 auto createMaskOp =
2237 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2238 if (!createMaskOp)
2239 return failure();
2240
2241 VectorType extractedMaskType =
2242 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2243
2244 if (!extractedMaskType)
2245 return failure();
2246
2247 auto maskOperands = createMaskOp.getOperands();
2248 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2249 VectorType maskType = createMaskOp.getVectorType();
2250
2251 bool containsUnknownDims = false;
2252 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2253
2254 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2255 dimIdx++) {
2256 int64_t pos = extractOpPos[dimIdx];
2257 Value operand = maskOperands[dimIdx];
2258 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2259 if (!constantOp) {
2260 // Bounds of this dim unknown.
2261 containsUnknownDims = true;
2262 continue;
2263 }
2264
2265 int64_t createMaskBound =
2266 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2267
2268 if (pos != ShapedType::kDynamic) {
2269 // If any position is outside the range from the `create_mask`, then the
2270 // extracted mask will be all-false.
2271 allFalse |= pos >= createMaskBound;
2272 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2273 // This dim is not all-true and since this is a dynamic index we don't
2274 // know if the extraction is within the true or false region.
2275 // Note: Zero dims have already been handled via getMaskFormat().
2276 containsUnknownDims = true;
2277 }
2278 }
2279
2280 if (allFalse) {
2281 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2282 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2283 } else if (!containsUnknownDims) {
2284 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2285 extractOp, extractedMaskType,
2286 maskOperands.drop_front(extractOpPos.size()));
2287 } else {
2288 return failure();
2289 }
2290 return success();
2291 }
2292};
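// For illustration (hypothetical values): with
//
//   %m = vector.create_mask %d0, %d1 : vector<4x8xi1>
//   %e = vector.extract %m[1] : vector<8xi1> from vector<4x8xi1>
//
// the pattern rewrites %e to an all-false constant when %d0 is a constant <= 1
// (row 1 lies outside the masked range), and to
// vector.create_mask %d1 : vector<8xi1> when %d0 is a constant >= 2.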
2293
2294// Folds extract(shape_cast(..)) into shape_cast when the total element count
2295// does not change.
2296LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2297 PatternRewriter &rewriter) {
2298 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2299 if (!castOp)
2300 return failure();
2301
2302 VectorType sourceType = castOp.getSourceVectorType();
2303 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2304 if (!targetType)
2305 return failure();
2306
2307 if (sourceType.getNumElements() != targetType.getNumElements())
2308 return failure();
2309
2310 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2311 castOp.getSource());
2312 return success();
2313}
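// For illustration (hypothetical values): when the extract keeps every element
// of the shape_cast source, the pair collapses into a single shape_cast:
//
//   %sc = vector.shape_cast %v : vector<2x2xf32> to vector<1x4xf32>
//   %e  = vector.extract %sc[0] : vector<4xf32> from vector<1x4xf32>
//
// becomes
//
//   %e = vector.shape_cast %v : vector<2x2xf32> to vector<4xf32>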
2314
2315/// Try to canonicalize the extraction of a subvector from a vector defined by
2316/// vector.from_elements. E.g.:
2317///
2318/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2319/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2320/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2321LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2322 PatternRewriter &rewriter) {
2323 // Dynamic positions are not supported.
2324 if (extractOp.hasDynamicPosition())
2325 return failure();
2326
2327 // Scalar extracts are handled by the folder.
2328 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2329 if (!resultType)
2330 return failure();
2331
2332 // Look for extracts from a from_elements op.
2333 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2334 if (!fromElementsOp)
2335 return failure();
2336 VectorType inputType = fromElementsOp.getType();
2337
2338 // Scalable vectors are not supported.
2339 if (resultType.isScalable() || inputType.isScalable())
2340 return failure();
2341
2342 // Compute the position of first extracted element and flatten/linearize the
2343 // position.
2344 SmallVector<int64_t> firstElementPos =
2345 llvm::to_vector(extractOp.getStaticPosition());
2346 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2347 int flatIndex = 0;
2348 int stride = 1;
2349 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2350 flatIndex += firstElementPos[i] * stride;
2351 stride *= inputType.getDimSize(i);
2352 }
2353
2354 // Replace the op with a smaller from_elements op.
2355 rewriter.replaceOpWithNewOp<FromElementsOp>(
2356 extractOp, resultType,
2357 fromElementsOp.getElements().slice(flatIndex,
2358 resultType.getNumElements()));
2359 return success();
2360}
2361
2362} // namespace
2363
2364void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2365 MLIRContext *context) {
2366 results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
2367 results.add(foldExtractFromShapeCastToShapeCast);
2368 results.add(foldExtractFromFromElements);
2369}
2370
2371static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
2372 SmallVectorImpl<int64_t> &results) {
2373 for (auto attr : arrayAttr)
2374 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2375}
2376
2377//===----------------------------------------------------------------------===//
2378// FMAOp
2379//===----------------------------------------------------------------------===//
2380
2381std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2382 return llvm::to_vector<4>(getVectorType().getShape());
2383}
2384
2385//===----------------------------------------------------------------------===//
2386// ToElementsOp
2387//===----------------------------------------------------------------------===//
2388
2389/// Returns true if all the `operands` are defined by `defOp`.
2390/// Otherwise, returns false.
2391static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2392 if (operands.empty())
2393 return false;
2394
2395 return llvm::all_of(operands, [&](Value operand) {
2396 Operation *currentDef = operand.getDefiningOp();
2397 return currentDef == defOp;
2398 });
2399}
2400
2401/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2402/// (%e0, %e1, ...). For example:
2403///
2404/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2405/// %1:3 = vector.to_elements %0 : vector<3xf32>
2406/// user_op %1#0, %1#1, %1#2
2407///
2408/// becomes:
2409///
2410/// user_op %a, %b, %c
2411///
2412static LogicalResult
2413foldToElementsFromElements(ToElementsOp toElementsOp,
2414 SmallVectorImpl<OpFoldResult> &results) {
2415 auto fromElementsOp =
2416 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2417 if (!fromElementsOp)
2418 return failure();
2419
2420 llvm::append_range(results, fromElementsOp.getElements());
2421 return success();
2422}
2423
2424/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2425///
2426/// Example:
2427/// %b = vector.broadcast %x : f32 to vector<3xf32>
2428/// %e:3 = vector.to_elements %b : vector<3xf32>
2429/// user_op %e#0, %e#1, %e#2
2430/// becomes:
2431/// user_op %x, %x, %x
2432///
2433/// The vector source case is handled by a canonicalization pattern.
2434static LogicalResult
2435foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2436 SmallVectorImpl<OpFoldResult> &results) {
2437 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2438 if (!bcastOp)
2439 return failure();
2440 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2441 if (isa<VectorType>(bcastOp.getSource().getType()))
2442 return failure();
2443
2444 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2445
2446 Value scalar = bcastOp.getSource();
2447 results.assign(resultVecType.getNumElements(), scalar);
2448 return success();
2449}
2450
2451LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2452 SmallVectorImpl<OpFoldResult> &results) {
2453 if (succeeded(foldToElementsFromElements(*this, results)))
2454 return success();
2455 return foldToElementsOfBroadcast(*this, results);
2456}
2457
2458LogicalResult
2459ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2460 ToElementsOp::Adaptor adaptor,
2461 SmallVectorImpl<Type> &inferredReturnTypes) {
2462 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2463 Type elType = vecType.getElementType();
2464 inferredReturnTypes.append(vecType.getNumElements(), elType);
2465 return success();
2466}
2467
2468/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2469/// vector.
2470/// - Build `vector.to_elements %v` and remap each destination element to the
2471/// corresponding source element using broadcast rules (match or 1 →
2472/// replicate).
2473///
2474/// Example:
2475/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2476/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2477/// becomes:
2478/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2479/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2480/// // %src_elems#1, %src_elems#0, %src_elems#1
2481struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2482 using Base::Base;
2483
2484 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2485 PatternRewriter &rewriter) const override {
2486 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2487 if (!bcastOp)
2488 return failure();
2489
2490 // Only handle broadcasts from a vector source here.
2491 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2492 if (!srcType)
2493 return failure();
2494
2495 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2496
2497 ArrayRef<int64_t> dstShape = dstType.getShape();
2498 ArrayRef<int64_t> srcShape = srcType.getShape();
2499
2500 int64_t dstRank = dstShape.size();
2501 int64_t srcRank = srcShape.size();
2502
2503 // Create elements for the broadcast source vector.
2504 auto srcElems = vector::ToElementsOp::create(
2505 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2506
2507 int64_t dstCount = llvm::product_of(dstShape);
2508
2509 SmallVector<Value> replacements;
2510 replacements.reserve(dstCount);
2511
2512 // For each element of the destination, determine which element of the
2513 // source should be used. We walk all destination positions using a single
2514 // counter, decode it into per-dimension indices, then build the matching
2515 // source position: use the same index where sizes match, and use 0 where
2516 // the source size is 1 (replication). This mapping is needed so we can
2517 // replace each result of to_elements with the corresponding element from
2518 // the broadcast source.
2519 // Inner-dimension stretch example:
2520 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2521 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2522 // becomes:
2523 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2524 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2525 // // %src_elems#1, %src_elems#0, %src_elems#1,
2526 // // %src_elems#2, %src_elems#3, %src_elems#2,
2527 // // %src_elems#3, %src_elems#2, %src_elems#3
2528
2529 // Row-major strides for the destination shape.
2530 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2531 // Row-major strides for the source shape.
2532 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2533 SmallVector<int64_t> dstIdx(dstRank);
2534 SmallVector<int64_t> srcIdx(srcRank);
2535 for (int64_t lin = 0; lin < dstCount; ++lin) {
2536 // Convert linear destination index to per-dimension indices.
2537 dstIdx = delinearize(lin, dstStrides);
2538 for (int64_t k = 0; k < srcRank; ++k)
2539 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2540 // Convert per-dimension source indices back to a linear index.
2541 int64_t srcLin = linearize(srcIdx, srcStrides);
2542 replacements.push_back(srcElems.getResult(srcLin));
2543 }
2544
2545 rewriter.replaceOp(toElementsOp, replacements);
2546 return success();
2547 }
2548};
2549
2550void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2551 MLIRContext *context) {
2552 results.add<ToElementsOfBroadcast>(context);
2553}
2554
2555//===----------------------------------------------------------------------===//
2556// FromElementsOp
2557//===----------------------------------------------------------------------===//
2558
2559/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2560///
2561/// Case #1: Input and output vectors are the same.
2562///
2563/// %0:3 = vector.to_elements %a : vector<3xf32>
2564/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2565/// user_op %1
2566///
2567/// becomes:
2568///
2569/// user_op %a
2570///
2571static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2572 OperandRange fromElemsOperands = fromElementsOp.getElements();
2573 if (fromElemsOperands.empty())
2574 return {};
2575
2576 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2577 if (!toElementsOp)
2578 return {};
2579
2580 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2581 return {};
2582
2583 // Case #1: Input and output vectors are the same. Forward the input vector.
2584 Value toElementsInput = toElementsOp.getSource();
2585 if (fromElementsOp.getType() == toElementsInput.getType() &&
2586 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2587 return toElementsInput;
2588 }
2589
2590 // TODO: Support cases with different input and output shapes and different
2591 // number of elements.
2592
2593 return {};
2594}
2595
2596/// Fold vector.from_elements to a constant when all operands are constants.
2597/// Example:
2598/// %c1 = arith.constant 1 : i32
2599/// %c2 = arith.constant 2 : i32
2600/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2601/// =>
2602/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2603///
2604static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2605 ArrayRef<Attribute> elements) {
2606 // Check for null or poison attributes before any processing.
2607 if (llvm::any_of(elements, [](Attribute attr) {
2608 return !attr || isa<ub::PoisonAttrInterface>(attr);
2609 }))
2610 return {};
2611
2612 // DenseElementsAttr only supports int/index/float/complex types.
2613 auto destVecType = fromElementsOp.getDest().getType();
2614 auto destEltType = destVecType.getElementType();
2615 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2616 return {};
2617
2618 // Constant attributes might have a different type than the return type.
2619 // Convert them before creating the dense elements attribute.
2620 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2621 return convertNumericAttr(attr, destEltType);
2622 });
2623
2624 return DenseElementsAttr::get(destVecType, convertedElements);
2625}
2626
2627OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2628 if (auto res = foldFromElementsToElements(*this))
2629 return res;
2630 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2631 return res;
2632
2633 return {};
2634}
2635
2636/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2637/// same. Example:
2638/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2639/// =>
2640/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2641static LogicalResult
2642rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2643 PatternRewriter &rewriter) {
2644 if (!llvm::all_equal(fromElementsOp.getElements()))
2645 return failure();
2646 rewriter.replaceOpWithNewOp<BroadcastOp>(
2647 fromElementsOp, fromElementsOp.getType(),
2648 fromElementsOp.getElements().front());
2649 return success();
2650}
2651
2652/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2653/// on a single extract. Example:
2654/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2655/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2656/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2657///
2658/// becomes
2659/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2660/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2661///
2662/// The requirements for this to be valid are
2663///
2664/// i) The elements are extracted from the same vector (%source).
2665///
2666/// ii) The elements form a suffix of %source. Specifically, the number
2667/// of elements is the same as the product of the last N dimension sizes
2668/// of %source, for some N.
2669///
2670/// iii) The elements are extracted contiguously in ascending order.
2671
2672class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2673
2674 using Base::Base;
2675
2676 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2677 PatternRewriter &rewriter) const override {
2678
2679 // Handled by `rewriteFromElementsAsBroadcast`.
2680 if (fromElements.getType().getNumElements() == 1)
2681 return failure();
2682
2683 // The common source that all elements are extracted from, if one exists.
2684 TypedValue<VectorType> source;
2685 // The position of the combined extract operation, if one is created.
2686 ArrayRef<int64_t> combinedPosition;
2687 // The expected index of extraction of the current element in the loop, if
2688 // elements are extracted contiguously in ascending order.
2689 SmallVector<int64_t> expectedPosition;
2690
2691 for (auto [insertIndex, element] :
2692 llvm::enumerate(fromElements.getElements())) {
2693
2694 // Check that the element is from a vector.extract operation.
2695 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2696 if (!extractOp) {
2697 return rewriter.notifyMatchFailure(fromElements,
2698 "element not from vector.extract");
2699 }
2700
2701 // Check condition (i) by checking that all elements have the same source
2702 // as the first element.
2703 if (insertIndex == 0) {
2704 source = extractOp.getSource();
2705 } else if (extractOp.getSource() != source) {
2706 return rewriter.notifyMatchFailure(fromElements,
2707 "element from different vector");
2708 }
2709
2710 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2711 int64_t rank = position.size();
2712 assert(rank == source.getType().getRank() &&
2713 "scalar extract must have full rank position");
2714
2715 // Check condition (ii) by checking that the position that the first
2716 // element is extracted from has sufficient trailing 0s. For example, in
2717 //
2718 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2719 // [...]
2720 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2721 //
2722 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2723 // elements, which is the number of elements of %elms, so this is valid.
2724 if (insertIndex == 0) {
2725 const int64_t numElms = fromElements.getType().getNumElements();
2726 int64_t numSuffixElms = 1;
2727 int64_t index = rank;
2728 while (index > 0 && position[index - 1] == 0 &&
2729 numSuffixElms < numElms) {
2730 numSuffixElms *= source.getType().getDimSize(index - 1);
2731 --index;
2732 }
2733 if (numSuffixElms != numElms) {
2734 return rewriter.notifyMatchFailure(
2735 fromElements, "elements do not form a suffix of source");
2736 }
2737 expectedPosition = llvm::to_vector(position);
2738 combinedPosition = position.drop_back(rank - index);
2739 }
2740
2741 // Check condition (iii).
2742 else if (expectedPosition != position) {
2743 return rewriter.notifyMatchFailure(
2744 fromElements, "elements not in ascending order (static order)");
2745 }
2746 increment(expectedPosition, source.getType().getShape());
2747 }
2748
2749 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2750 fromElements.getLoc(), source, combinedPosition);
2751
2752 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2753 fromElements, fromElements.getType(), extracted);
2754
2755 return success();
2756 }
2757
2758 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2759 static void increment(MutableArrayRef<int64_t> indices,
2760 ArrayRef<int64_t> shape) {
2761 for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2762 indices[dim] += 1;
2763 if (indices[dim] < shape[dim])
2764 break;
2765 indices[dim] = 0;
2766 }
2767 }
2768};
2769
2770void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2771 MLIRContext *context) {
2772 results.add(rewriteFromElementsAsBroadcast);
2773 results.add<FromElementsToShapeCast>(context);
2774}
2775
2776//===----------------------------------------------------------------------===//
2777// BroadcastOp
2778//===----------------------------------------------------------------------===//
2779
2780void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2781 SetIntRangeFn setResultRanges) {
2782 setResultRanges(getResult(), argRanges.front());
2783}
2784
2785std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2786 return llvm::to_vector<4>(getResultVectorType().getShape());
2787}
2788
2789/// Return the dimensions of the result vector that were formerly ones in the
2790/// source vector and thus correspond to "dim-1" broadcasting.
2791static llvm::SetVector<int64_t>
2792computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
2793 ArrayRef<int64_t> dstShape) {
2794 int64_t rankDiff = dstShape.size() - srcShape.size();
2795 int64_t dstDim = rankDiff;
2796 llvm::SetVector<int64_t> res;
2797 for (auto [s1, s2] :
2798 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2799 if (s1 != s2) {
2800 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2801 res.insert(dstDim);
2802 }
2803 ++dstDim;
2804 }
2805 return res;
2806}
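// For illustration (hypothetical shapes): with srcShape = (1, 4) and
// dstShape = (2, 3, 4), the rank difference is 1, so source dims (1, 4) are
// compared against destination dims (3, 4); only the first pair differs, so
// the returned set is {1} (destination dimension 1 is a "dim-1" broadcast).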
2807
2808llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2809 // A scalar broadcast has no broadcasted unit dims.
2810 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2811 if (!srcVectorType)
2812 return {};
2813 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2814 getResultVectorType().getShape());
2815}
2816
2817/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2818/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2819/// This requires (and asserts) that the broadcast is free of "dim-1"
2820/// broadcasting.
2821/// Since vector.broadcast only allows expanding leading dimensions, an extra
2822/// vector.transpose may be inserted to make the broadcast possible.
2823/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
2824/// the helper will assert. This means:
2825/// 1. `dstShape` must not be empty.
2826/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
2827/// 3. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
2828///    must match the `value` shape.
2829Value BroadcastOp::createOrFoldBroadcastOp(
2830 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
2831 const llvm::SetVector<int64_t> &broadcastedDims) {
2832 assert(!dstShape.empty() && "unexpected empty dst shape");
2833
2834 // Step 1. Well-formedness check.
2835 SmallVector<int64_t> checkShape;
2836 for (int i = 0, e = dstShape.size(); i < e; ++i) {
2837 if (broadcastedDims.contains(i))
2838 continue;
2839 checkShape.push_back(dstShape[i]);
2840 }
2841 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
2842 "ill-formed broadcastedDims contains values not confined to "
2843 "destVectorShape");
2844
2845 Location loc = value.getLoc();
2846 Type elementType = getElementTypeOrSelf(value.getType());
2847 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
2848 VectorType dstVectorType = VectorType::get(dstShape, elementType);
2849
2850 // Step 2. If scalar -> dstShape broadcast, just do it.
2851 if (!srcVectorType) {
2852 assert(checkShape.empty() &&
2853 "ill-formed createOrFoldBroadcastOp arguments");
2854 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
2855 }
2856
2857 assert(srcVectorType.getShape().equals(checkShape) &&
2858 "ill-formed createOrFoldBroadcastOp arguments");
2859
2860 // Step 3. Since vector.broadcast only allows creating leading dims,
2861 // vector -> dstShape broadcast may require a transpose.
2862 // Traverse the dims in order and construct:
2863 // 1. The leading entries of the broadcastShape that is guaranteed to be
2864 // achievable by a simple broadcast.
2865 // 2. The induced permutation for the subsequent vector.transpose that will
2866 // bring us from `broadcastShape` back to the desired `dstShape`.
2867 // If the induced permutation is not the identity, create a vector.transpose.
2868 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
2869 broadcastShape.reserve(dstShape.size());
2870 // Consider the example:
2871 // srcShape = 2x4
2872 // dstShape = 1x2x3x4x5
2873 // broadcastedDims = [0, 2, 4]
2874 //
2875 // We want to build:
2876 // broadcastShape = 1x3x5x2x4
2877 // permutation = [0, 2, 4, 1, 3]
2878 // ---V--- -----V-----
2879 // leading broadcast part src shape part
2880 //
2881 // Note that the trailing dims of broadcastShape are exactly the srcShape
2882 // by construction.
2883 // nextSrcShapeDim is used to keep track of where in the permutation the
2884 // "src shape part" occurs.
2885 int64_t nextSrcShapeDim = broadcastedDims.size();
2886 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
2887 if (broadcastedDims.contains(i)) {
2888 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
2889 // bring it to the head of the broadcastShape.
2890 // It will need to be permuted back from `broadcastShape.size() - 1` into
2891 // position `i`.
2892 broadcastShape.push_back(dstShape[i]);
2893 permutation[i] = broadcastShape.size() - 1;
2894 } else {
2895 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
2896 // shape and needs to be permuted into position `i`.
2897 // Don't touch `broadcastShape` here, the whole srcShape will be
2898 // appended after.
2899 permutation[i] = nextSrcShapeDim++;
2900 }
2901 }
2902 // 3.c. Append the srcShape.
2903 llvm::append_range(broadcastShape, srcVectorType.getShape());
2904
2905 // Ensure there are no "dim-1" broadcasts.
2906 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
2907 .empty() &&
2908 "unexpected \"dim-1\" broadcast");
2909
2910 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
2911 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
2912 vector::BroadcastableToResult::Success &&
2913 "must be broadcastable");
2914 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
2915 // Step 4. If we find any dimension that indeed needs to be permuted,
2916 // immediately return a new vector.transpose.
2917 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
2918 if (permutation[i] != i)
2919 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
2920 // Otherwise return res.
2921 return res;
2922}
2923
2924BroadcastableToResult mlir::vector::isBroadcastableTo(
2925 Type srcType, VectorType dstVectorType,
2926 std::pair<VectorDim, VectorDim> *mismatchingDims) {
2927 // Broadcast scalar to vector of the same element type.
2928 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
2929 srcType == getElementTypeOrSelf(dstVectorType))
2930 return BroadcastableToResult::Success;
2931 // From now on, only vectors broadcast.
2932 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
2933 if (!srcVectorType)
2934 return BroadcastableToResult::SourceTypeNotAVector;
2935
2936 int64_t srcRank = srcVectorType.getRank();
2937 int64_t dstRank = dstVectorType.getRank();
2938 if (srcRank > dstRank)
2939 return BroadcastableToResult::SourceRankHigher;
2940 // Source has an exact match or singleton value for all trailing dimensions
2941 // (all leading dimensions are simply duplicated).
2942 int64_t lead = dstRank - srcRank;
2943 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
2944 // Have mismatching dims (in the sense of vector.broadcast semantics) been
2945 // encountered?
2946 bool foundMismatchingDims = false;
2947
2948 // Check fixed-width dims.
2949 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
2950 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
2951 if (srcDim != 1 && srcDim != dstDim)
2952 foundMismatchingDims = true;
2953
2954 // Check scalable flags.
2955 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
2956 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
2957 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
2958 // 1 -> [N] is fine, everything else should be rejected when mixing
2959 // fixed-width and scalable dims
2960 (srcDimScalableFlag != dstDimScalableFlag &&
2961 (srcDim != 1 || srcDimScalableFlag)))
2962 foundMismatchingDims = true;
2963
2964 if (foundMismatchingDims) {
2965 if (mismatchingDims != nullptr) {
2966 mismatchingDims->first.dim = srcDim;
2967 mismatchingDims->first.isScalable = srcDimScalableFlag;
2968
2969 mismatchingDims->second.dim = dstDim;
2970 mismatchingDims->second.isScalable = dstDimScalableFlag;
2971 }
2972 return BroadcastableToResult::DimensionMismatch;
2973 }
2974 }
2975
2976 return BroadcastableToResult::Success;
2977 }
2978
2979LogicalResult BroadcastOp::verify() {
2980 std::pair<VectorDim, VectorDim> mismatchingDims;
2981 BroadcastableToResult res = isBroadcastableTo(
2982 getSourceType(), getResultVectorType(), &mismatchingDims);
2983 if (res == BroadcastableToResult::Success)
2984 return success();
2985 if (res == BroadcastableToResult::SourceRankHigher)
2986 return emitOpError("source rank higher than destination rank");
2987 if (res == BroadcastableToResult::DimensionMismatch) {
2988 return emitOpError("dimension mismatch (")
2989 << (mismatchingDims.first.isScalable ? "[" : "")
2990 << mismatchingDims.first.dim
2991 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
2992 << (mismatchingDims.second.isScalable ? "[" : "")
2993 << mismatchingDims.second.dim
2994 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
2995 }
2996 if (res == BroadcastableToResult::SourceTypeNotAVector)
2997 return emitOpError("source type is not a vector");
2998 llvm_unreachable("unexpected vector.broadcast op error");
2999}
3000
3001// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3002// with broadcast's result type and shape_cast only adds or removes ones in the
3003// leading dimensions.
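// Illustrative example (types chosen arbitrarily, not taken from a test):
//
//   %sc = vector.shape_cast %x : vector<4xf32> to vector<1x4xf32>
//   %b  = vector.broadcast %sc : vector<1x4xf32> to vector<2x4xf32>
//
// is folded to
//
//   %b = vector.broadcast %x : vector<4xf32> to vector<2x4xf32>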
3004static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3005 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3006 if (!srcShapeCast)
3007 return failure();
3008
3009 VectorType srcType = srcShapeCast.getSourceVectorType();
3010 VectorType destType = broadcastOp.getResultVectorType();
3011 // Check type compatibility.
3012 if (vector::isBroadcastableTo(srcType, destType) !=
3013 vector::BroadcastableToResult::Success)
3014 return failure();
3015
3016 ArrayRef<int64_t> srcShape = srcType.getShape();
3017 ArrayRef<int64_t> shapecastShape =
3018 srcShapeCast.getResultVectorType().getShape();
3019 // Trailing dimensions should be the same if shape_cast only alters the
3020 // leading dimensions.
3021 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3022 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3023 shapecastShape.take_back(numTrailingDims)))
3024 return failure();
3025
3026 assert(all_of(srcShape.drop_back(numTrailingDims),
3027 [](int64_t E) { return E == 1; }) &&
3028 all_of(shapecastShape.drop_back(numTrailingDims),
3029 [](int64_t E) { return E == 1; }) &&
3030 "ill-formed shape_cast");
3031
3032 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3033 return success();
3034}
3035
3036OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3037 if (getSourceType() == getResultVectorType())
3038 return getSource();
3039 if (succeeded(foldBroadcastOfShapeCast(*this)))
3040 return getResult();
3041
3042 if (!adaptor.getSource())
3043 return {};
3044 auto vectorType = getResultVectorType();
3045 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3046 if (vectorType.getElementType() != attr.getType())
3047 return {};
3048 return DenseElementsAttr::get(vectorType, attr);
3049 }
3050 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3051 if (vectorType.getElementType() != attr.getType())
3052 return {};
3053 return DenseElementsAttr::get(vectorType, attr);
3054 }
3055 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3056 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3057 if (llvm::dyn_cast<ub::PoisonAttr>(adaptor.getSource()))
3058 return ub::PoisonAttr::get(getContext());
3059 return {};
3060}
3061
3062namespace {
3063
3064// Fold broadcast1(broadcast2(x)) into broadcast1(x).
3065struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3066 using Base::Base;
3067
3068 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3069 PatternRewriter &rewriter) const override {
3070 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3071 if (!srcBroadcast)
3072 return failure();
3073 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3074 broadcastOp.getResultVectorType(),
3075 srcBroadcast.getSource());
3076 return success();
3077 }
3078};
3079} // namespace
3080
3081void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3082 MLIRContext *context) {
3083 // BroadcastToShapeCast is not a default canonicalization; it is opted into by
3084 // calling `populateCastAwayVectorLeadingOneDimPatterns`.
3085 results.add<BroadcastFolder>(context);
3086}
3087
3088//===----------------------------------------------------------------------===//
3089// ShuffleOp
3090//===----------------------------------------------------------------------===//
3091
3092LogicalResult ShuffleOp::verify() {
3093 VectorType resultType = getResultVectorType();
3094 VectorType v1Type = getV1VectorType();
3095 VectorType v2Type = getV2VectorType();
3096 // Verify ranks.
3097 int64_t resRank = resultType.getRank();
3098 int64_t v1Rank = v1Type.getRank();
3099 int64_t v2Rank = v2Type.getRank();
3100 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3101 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3102 if (!wellFormed0DCase && !wellFormedNDCase)
3103 return emitOpError("rank mismatch");
3104
3105 // Verify all but leading dimension sizes.
3106 for (int64_t r = 1; r < v1Rank; ++r) {
3107 int64_t resDim = resultType.getDimSize(r);
3108 int64_t v1Dim = v1Type.getDimSize(r);
3109 int64_t v2Dim = v2Type.getDimSize(r);
3110 if (resDim != v1Dim || v1Dim != v2Dim)
3111 return emitOpError("dimension mismatch");
3112 }
3113 // Verify mask length.
3114 ArrayRef<int64_t> mask = getMask();
3115 int64_t maskLength = mask.size();
3116 if (maskLength <= 0)
3117 return emitOpError("invalid mask length");
3118 if (maskLength != resultType.getDimSize(0))
3119 return emitOpError("mask length mismatch");
3120 // Verify all indices.
3121 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3122 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3123 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3124 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3125 return emitOpError("mask index #") << (idx + 1) << " out of range";
3126 }
3127 return success();
3128}
3129
3130LogicalResult
3131ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
3132 ShuffleOp::Adaptor adaptor,
3133 SmallVectorImpl<Type> &inferredReturnTypes) {
3134 auto v1Type = llvm::cast<VectorType>(adaptor.getV1().getType());
3135 auto v1Rank = v1Type.getRank();
3136 // Construct resulting type: leading dimension matches mask
3137 // length, all trailing dimensions match the operands.
3138 SmallVector<int64_t, 4> shape;
3139 shape.reserve(v1Rank);
3140 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3141 // In the 0-D case there is no trailing shape to append.
3142 if (v1Rank > 0)
3143 llvm::append_range(shape, v1Type.getShape().drop_front());
3144 inferredReturnTypes.push_back(
3145 VectorType::get(shape, v1Type.getElementType()));
3146 return success();
3147}
3148
3149template <typename T>
3150static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3151 T expected = begin;
3152 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3153 return value == expected++;
3154 });
3155}
3156
3157OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3158 auto v1Type = getV1VectorType();
3159 auto v2Type = getV2VectorType();
3160
3161 assert(!v1Type.isScalable() && !v2Type.isScalable() &&
3162 "Vector shuffle does not support scalable vectors");
3163
3164 // For consistency, the 0-D shuffle return type is 1-D, so this case cannot be
3165 // handled as a fold; it is canonicalized into a vector.broadcast instead.
3166 if (v1Type.getRank() == 0)
3167 return {};
3168
3169 // Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3170 auto mask = getMask();
3171 if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3172 return getV1();
3173 // Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3174 if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3175 return getV2();
3176
3177 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3178 if (!v1Attr || !v2Attr)
3179 return {};
3180
3181 // Fold shuffle poison, poison -> poison.
3182 bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
3183 bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
3184 if (isV1Poison && isV2Poison)
3185 return ub::PoisonAttr::get(getContext());
3186
3187 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
3188 // manipulation.
3189 if (v1Type.getRank() != 1)
3190 return {};
3191
3192 // Poison input attributes need special handling as they are not
3193 // DenseElementsAttr. If an index is poison, we select the first element of
3194 // the first non-poison input.
3195 SmallVector<Attribute> v1Elements, v2Elements;
3196 Attribute poisonElement;
3197 if (!isV2Poison) {
3198 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3199 if (!v2DenseAttr)
3200 return {};
3201 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3202 poisonElement = v2Elements[0];
3203 }
3204 if (!isV1Poison) {
3205 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3206 if (!v1DenseAttr)
3207 return {};
3208 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3209 poisonElement = v1Elements[0];
3210 }
3211
3212 SmallVector<Attribute> results;
3213 int64_t v1Size = v1Type.getDimSize(0);
3214 for (int64_t maskIdx : mask) {
3215 Attribute indexedElm;
3216 // TODO: Return a partial poison vector when supported by the UB dialect.
3217 if (maskIdx == ShuffleOp::kPoisonIndex) {
3218 indexedElm = poisonElement;
3219 } else {
3220 if (maskIdx < v1Size)
3221 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3222 else
3223 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3224 }
3225
3226 results.push_back(indexedElm);
3227 }
3228
3229 return DenseElementsAttr::get(getResultVectorType(), results);
3230}
3231
3232namespace {
3233
3234// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3235// to a broadcast.
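// Illustrative example (types chosen arbitrarily):
//
//   %r = vector.shuffle %a, %b [0] : vector<f32>, vector<f32>
//
// is rewritten to
//
//   %r = vector.broadcast %a : vector<f32> to vector<1xf32>
//
// and a [1] mask would broadcast %b instead.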
3236struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3237 using Base::Base;
3238
3239 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3240 PatternRewriter &rewriter) const override {
3241 VectorType v1VectorType = shuffleOp.getV1VectorType();
3242 ArrayRef<int64_t> mask = shuffleOp.getMask();
3243 if (v1VectorType.getRank() > 0)
3244 return failure();
3245 if (mask.size() != 1)
3246 return failure();
3247 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3248 if (mask[0] == 0)
3249 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3250 shuffleOp.getV1());
3251 else
3252 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3253 shuffleOp.getV2());
3254 return success();
3255 }
3256};
3257
3258/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3259/// vector.broadcast with a scalar operand, return the scalar value that is
3260/// splatted. Otherwise return null.
3261///
3262/// Example:
3263///
3264/// scalar_source --> vector.broadcast --> value - return scalar_source
3265static Value getScalarSplatSource(Value value) {
3266 // Block arguments have no defining op and thus cannot be a broadcast:
3267 Operation *defOp = value.getDefiningOp();
3268 if (!defOp)
3269 return {};
3270
3271 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3272
3273 // Not broadcast (and not splat):
3274 if (!broadcast)
3275 return {};
3276
3277 // Broadcast of a vector:
3278 if (isa<VectorType>(broadcast.getSourceType()))
3279 return {};
3280
3281 // Broadcast of a scalar:
3282 return broadcast.getSource();
3283}
3284
3285/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
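///
/// Illustrative example (types chosen arbitrarily): if both %a and %b are
/// produced by `vector.broadcast %s : f32 to vector<4xf32>`, then a
/// `vector.shuffle %a, %b [0, 5, 2, 7]` can be replaced by
/// `vector.broadcast %s : f32 to vector<4xf32>`, since every element equals %s.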
3286class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3287public:
3288 using Base::Base;
3289
3290 LogicalResult matchAndRewrite(ShuffleOp op,
3291 PatternRewriter &rewriter) const override {
3292 Value splat = getScalarSplatSource(op.getV1());
3293 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3294 return failure();
3295
3296 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3297 return success();
3298 }
3299};
3300
3301/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3302/// vector.interleave.
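///
/// Illustrative example (types chosen arbitrarily):
///
///   %r = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
///        : vector<4xi32>, vector<4xi32>
///
/// is rewritten to
///
///   %r = vector.interleave %a, %b : vector<4xi32> -> vector<8xi32>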
3303class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3304public:
3305 using Base::Base;
3306
3307 LogicalResult matchAndRewrite(ShuffleOp op,
3308 PatternRewriter &rewriter) const override {
3309 VectorType resultType = op.getResultVectorType();
3310 if (resultType.isScalable())
3311 return rewriter.notifyMatchFailure(
3312 op, "ShuffleOp can't represent a scalable interleave");
3313
3314 if (resultType.getRank() != 1)
3315 return rewriter.notifyMatchFailure(
3316 op, "ShuffleOp can't represent an n-D interleave");
3317
3318 VectorType sourceType = op.getV1VectorType();
3319 if (sourceType != op.getV2VectorType() ||
3320 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3321 return rewriter.notifyMatchFailure(
3322 op, "ShuffleOp types don't match an interleave");
3323 }
3324
3325 ArrayRef<int64_t> shuffleMask = op.getMask();
3326 int64_t resultVectorSize = resultType.getNumElements();
3327 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3328 int64_t maskValueA = shuffleMask[i * 2];
3329 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3330 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3331 return rewriter.notifyMatchFailure(op,
3332 "ShuffleOp mask not interleaving");
3333 }
3334
3335 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3336 return success();
3337 }
3338};
3339
3340} // namespace
3341
3342void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3343 MLIRContext *context) {
3344 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp>(
3345 context);
3346}
3347
3348//===----------------------------------------------------------------------===//
3349// InsertOp
3350//===----------------------------------------------------------------------===//
3351
3352void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3353 SetIntRangeFn setResultRanges) {
3354 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3355}
3356
3357void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3358 Value source, Value dest) {
3359 auto vectorTy = cast<VectorType>(dest.getType());
3360 build(builder, result, source, dest,
3361 SmallVector<int64_t>(vectorTy.getRank(), 0));
3362}
3363
3364void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3365 Value source, Value dest, int64_t position) {
3366 build(builder, result, source, dest, ArrayRef<int64_t>{position});
3367}
3368
3369void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3370 Value source, Value dest, OpFoldResult position) {
3371 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3372}
3373
3374void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3375 Value source, Value dest,
3376 ArrayRef<int64_t> position) {
3377 SmallVector<OpFoldResult> posVals;
3378 posVals.reserve(position.size());
3379 llvm::transform(position, std::back_inserter(posVals),
3380 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3381 build(builder, result, source, dest, posVals);
3382}
3383
3384void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3385 Value source, Value dest,
3386 ArrayRef<OpFoldResult> position) {
3387 SmallVector<int64_t> staticPos;
3388 SmallVector<Value> dynamicPos;
3389 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3390 build(builder, result, source, dest, dynamicPos,
3391 builder.getDenseI64ArrayAttr(staticPos));
3392}
3393
3394LogicalResult InsertOp::verify() {
3395 if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3396 if (srcTy.getRank() == 0)
3397 return emitError(
3398 "expected a scalar instead of a 0-d vector as the source operand");
3399
3400 SmallVector<OpFoldResult> position = getMixedPosition();
3401 auto destVectorType = getDestVectorType();
3402 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3403 return emitOpError(
3404 "expected position attribute of rank no greater than dest vector rank");
3405 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3406 if (srcVectorType &&
3407 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3408 static_cast<unsigned>(destVectorType.getRank())))
3409 return emitOpError("expected position attribute rank + source rank to "
3410 "match dest vector rank");
3411 if (!srcVectorType &&
3412 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3413 return emitOpError(
3414 "expected position attribute rank to match the dest vector rank");
3415 for (auto [idx, pos] : llvm::enumerate(position)) {
3416 if (auto attr = dyn_cast<Attribute>(pos)) {
3417 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3418 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3419 destVectorType.getDimSize(idx))) {
3420 return emitOpError("expected position attribute #")
3421 << (idx + 1)
3422 << " to be a non-negative integer smaller than the "
3423 "corresponding "
3424 "dest vector dimension";
3425 }
3426 }
3427 }
3428 return success();
3429}
3430
3431 // Calculate the linearized position of the contiguous chunk of elements to
3432 // insert, based on the shape of the value to insert and the positions to
3433 // insert at.
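// Illustrative example (values chosen arbitrarily): for destTy = vector<2x3xf32>
// and positions = [1], the completed position is [1, 0], the row-major strides
// are [3, 1], and the returned linearized start position is 1 * 3 + 0 * 1 = 3.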
3434static int64_t calculateInsertPosition(VectorType destTy,
3435 ArrayRef<int64_t> positions) {
3436 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3437 assert(positions.size() <= completePositions.size() &&
3438 "positions size must be less than or equal to destTy rank");
3439 copy(positions, completePositions.begin());
3440 return linearize(completePositions, computeStrides(destTy.getShape()));
3441}
3442
3443namespace {
3444
3445// If insertOp is only inserting unit dimensions it can be transformed to a
3446// broadcast.
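// Illustrative example (types chosen arbitrarily):
//
//   %0 = vector.insert %v, %dst [0, 0] : vector<4xf32> into vector<1x1x4xf32>
//
// has a source with as many elements as the destination and can be rewritten as
//
//   %0 = vector.broadcast %v : vector<4xf32> to vector<1x1x4xf32>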
3447class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3448public:
3449 using Base::Base;
3450
3451 LogicalResult matchAndRewrite(InsertOp insertOp,
3452 PatternRewriter &rewriter) const override {
3453 auto srcVecType =
3454 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3455 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3456 srcVecType.getNumElements())
3457 return failure();
3458 rewriter.replaceOpWithNewOp<BroadcastOp>(
3459 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3460 return success();
3461 }
3462};
3463
3464 /// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
3465class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3466public:
3467 using Base::Base;
3468
3469 LogicalResult matchAndRewrite(InsertOp op,
3470 PatternRewriter &rewriter) const override {
3471
3472 Value splat = getScalarSplatSource(op.getValueToStore());
3473 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3474 return failure();
3475
3476 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3477 return success();
3478 }
3479};
3480
3481/// Pattern to optimize a chain of insertions.
3482///
3483/// This pattern identifies chains of vector.insert operations that:
3484/// 1. Only insert values at static positions.
3485/// 2. Completely initialize all elements in the resulting vector.
3486/// 3. All intermediate insert operations have only one use.
3487///
3488/// When these conditions are met, the entire chain can be replaced with a
3489/// single vector.from_elements operation.
3490///
3491/// To keep this pattern simple, and avoid spending too much time on matching
3492/// fragmented insert chains, this pattern only considers the last insert op in
3493/// the chain.
3494///
3495/// Example transformation:
3496/// %poison = ub.poison : vector<2xi32>
3497/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3498/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3499/// ->
3500/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3501class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3502public:
3503 using Base::Base;
3504 LogicalResult matchAndRewrite(InsertOp op,
3505 PatternRewriter &rewriter) const override {
3506
3507 VectorType destTy = op.getDestVectorType();
3508 if (destTy.isScalable())
3509 return failure();
3510 // Ensure this is the trailing vector.insert op in a chain of inserts.
3511 for (Operation *user : op.getResult().getUsers())
3512 if (auto insertOp = dyn_cast<InsertOp>(user))
3513 if (insertOp.getDest() == op.getResult())
3514 return failure();
3515
3516 InsertOp currentOp = op;
3517 SmallVector<InsertOp> chainInsertOps;
3518 while (currentOp) {
3519 // Check cond 1: Dynamic position is not supported.
3520 if (currentOp.hasDynamicPosition())
3521 return failure();
3522
3523 chainInsertOps.push_back(currentOp);
3524 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3525 // Check cond 3: Intermediate inserts have only one use to avoid an
3526 // explosion of vectors.
3527 if (currentOp && !currentOp->hasOneUse())
3528 return failure();
3529 }
3530
3531 int64_t vectorSize = destTy.getNumElements();
3532 int64_t initializedCount = 0;
3533 SmallVector<bool> initializedDestIdxs(vectorSize, false);
3534 SmallVector<int64_t> pendingInsertPos;
3535 SmallVector<int64_t> pendingInsertSize;
3536 SmallVector<Value> pendingInsertValues;
3537
3538 for (auto insertOp : chainInsertOps) {
3539 // This pattern can do nothing with poison index.
3540 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3541 return failure();
3542
3543 // Calculate the linearized position for inserting elements.
3544 int64_t insertBeginPosition =
3545 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3546
3547 // The valueToStore operand may be a vector or a scalar. Need to handle
3548 // both cases.
3549 int64_t insertSize = 1;
3550 if (auto srcVectorType =
3551 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3552 insertSize = srcVectorType.getNumElements();
3553
3554 assert(insertBeginPosition + insertSize <= vectorSize &&
3555 "insert would overflow the vector");
3556
3557 for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3558 insertBeginPosition + insertSize)) {
3559 if (initializedDestIdxs[index])
3560 continue;
3561 initializedDestIdxs[index] = true;
3562 ++initializedCount;
3563 }
3564
3565 // Defer creating any new ops until we are sure the pattern will
3566 // succeed.
3567 pendingInsertPos.push_back(insertBeginPosition);
3568 pendingInsertSize.push_back(insertSize);
3569 pendingInsertValues.push_back(insertOp.getValueToStore());
3570
3571 if (initializedCount == vectorSize)
3572 break;
3573 }
3574
3575 // Check cond 2: all positions must be initialized.
3576 if (initializedCount != vectorSize)
3577 return failure();
3578
3579 SmallVector<Value> elements(vectorSize);
3580 for (auto [insertBeginPosition, insertSize, valueToStore] :
3581 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3582 pendingInsertValues))) {
3583 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3584
3585 if (!srcVectorType) {
3586 elements[insertBeginPosition] = valueToStore;
3587 continue;
3588 }
3589
3590 SmallVector<Type> elementToInsertTypes(insertSize,
3591 srcVectorType.getElementType());
3592 // Get all elements from the vector in row-major order.
3593 auto elementsToInsert = vector::ToElementsOp::create(
3594 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3595 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3596 elements[insertBeginPosition + linearIdx] =
3597 elementsToInsert.getResult(linearIdx);
3598 }
3599 }
3600
3601 rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3602 return success();
3603 }
3604};
3605
3606} // namespace
3607
3608 static Attribute
3609 foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3610 Attribute dstAttr,
3611 int64_t maxVectorSizeFoldThreshold) {
3612 if (insertOp.hasDynamicPosition())
3613 return {};
3614
3615 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3616 if (!denseDst)
3617 return {};
3618
3619 if (!srcAttr) {
3620 return {};
3621 }
3622
3623 VectorType destTy = insertOp.getDestVectorType();
3624 if (destTy.isScalable())
3625 return {};
3626
3627 // Make sure we do not create too many large constants.
3628 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3629 !insertOp->hasOneUse())
3630 return {};
3631
3632 // Calculate the linearized position for inserting elements.
3633 int64_t insertBeginPosition =
3634 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3635 SmallVector<Attribute> insertedValues;
3636 Type destEltType = destTy.getElementType();
3637
3638 /// Converts attribute to the expected type if there's
3639 /// a mismatch.
3640 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3641 for (auto value : denseSource.getValues<Attribute>())
3642 insertedValues.push_back(convertNumericAttr(value, destEltType));
3643 } else {
3644 insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3645 }
3646
3647 auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3648 copy(insertedValues, allValues.begin() + insertBeginPosition);
3649 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3650
3651 return newAttr;
3652}
3653
3654/// Folder to replace the `dest` operand of the insert op with the root dest of
3655/// the insert op use chain.
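///
/// Illustrative example (types chosen arbitrarily): in
///
///   %0 = vector.insert %a, %dst [1] : f32 into vector<4xf32>
///   %1 = vector.insert %b, %0 [1] : f32 into vector<4xf32>
///
/// the second insert overwrites the position written by the first, so its
/// `dest` operand can be replaced with %dst directly.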
3656static Value foldInsertUseChain(InsertOp insertOp) {
3657 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3658 if (!destInsert)
3659 return {};
3660
3661 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3662 return {};
3663
3664 insertOp.setOperand(1, destInsert.getDest());
3665 return insertOp.getResult();
3666}
3667
3668void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3669 MLIRContext *context) {
3670 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3671 InsertChainFullyInitialized>(context);
3672}
3673
3674OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
3675 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3676 // unless the source vector constant has a single use.
3677 constexpr int64_t vectorSizeFoldThreshold = 256;
3678 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3679 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3680 // (type mismatch).
3681 if (getNumIndices() == 0 && getValueToStoreType() == getType())
3682 return getValueToStore();
3683 // Fold `arith.constant` indices into the `vector.insert` operation.
3684 // Do not stop here as this fold may enable subsequent folds that require
3685 // constant indices.
3686 SmallVector<Value> operands = {getValueToStore(), getDest()};
3687 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
3688
3689 if (auto res = foldInsertUseChain(*this))
3690 return res;
3691 if (auto res = foldPoisonIndexInsertExtractOp(
3692 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
3693 return res;
3694 if (auto res = foldDenseElementsAttrDestInsertOp(
3695 *this, adaptor.getValueToStore(), adaptor.getDest(),
3696 vectorSizeFoldThreshold)) {
3697 return res;
3698 }
3699
3700 return inplaceFolded;
3701}
3702
3703//===----------------------------------------------------------------------===//
3704// InsertStridedSliceOp
3705//===----------------------------------------------------------------------===//
3706
3707void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
3708 Value source, Value dest,
3709 ArrayRef<int64_t> offsets,
3710 ArrayRef<int64_t> strides) {
3711 result.addOperands({source, dest});
3712 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
3713 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
3714 result.addTypes(dest.getType());
3715 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
3716 offsetsAttr);
3717 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
3718 stridesAttr);
3719}
3720
3721// TODO: Should be moved to Tablegen ConfinedAttr attributes.
3722template <typename OpType>
3723static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
3724 ArrayAttr arrayAttr,
3725 ArrayRef<int64_t> shape,
3726 StringRef attrName) {
3727 if (arrayAttr.size() > shape.size())
3728 return op.emitOpError("expected ")
3729 << attrName << " attribute of rank no greater than vector rank";
3730 return success();
3731}
3732
3733 // Returns success if all integers in `arrayAttr` are in the admissible
3734 // interval. If `halfOpen` is true then the admissible interval is [min, max).
3735 // Otherwise, the admissible interval is [min, max].
3736template <typename OpType>
3737 static LogicalResult
3738 isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
3739 int64_t max, StringRef attrName,
3740 bool halfOpen = true) {
3741 for (auto attr : arrayAttr) {
3742 auto val = llvm::cast<IntegerAttr>(attr).getInt();
3743 auto upper = max;
3744 if (!halfOpen)
3745 upper += 1;
3746 if (val < min || val >= upper)
3747 return op.emitOpError("expected ") << attrName << " to be confined to ["
3748 << min << ", " << upper << ")";
3749 }
3750 return success();
3751}
3752
3753 // Returns success if all integers in `arrayAttr` are within the admissible
3754 // interval, where max is the corresponding dimension size in `shape`. If
3755 // `halfOpen` is true the interval is [min, max); otherwise it is [min, max].
3756template <typename OpType>
3757 static LogicalResult
3758 isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
3759 ArrayRef<int64_t> shape, StringRef attrName,
3760 bool halfOpen = true, int64_t min = 0) {
3761 for (auto [index, attrDimPair] :
3762 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
3763 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
3764 int64_t max = std::get<1>(attrDimPair);
3765 if (!halfOpen)
3766 max += 1;
3767 if (val < min || val >= max)
3768 return op.emitOpError("expected ")
3769 << attrName << " dimension " << index << " to be confined to ["
3770 << min << ", " << max << ")";
3771 }
3772 return success();
3773}
3774
3775 // Returns success if, for all indices i = 0..shape.size()-1, the value
3776 // val = `arrayAttr1[i]` + `arrayAttr2[i]`
3777 // lies in the admissible interval, where max is `shape[i]`.
3778 // If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
3779 // the admissible interval is [min, max].
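// Illustrative example (values chosen arbitrarily): with shape = [8], offsets
// arrayAttr1 = [6] and sizes arrayAttr2 = [4], the sum 10 exceeds the dimension
// size 8 and verification fails; offsets [4] with sizes [4] (sum 8) is accepted
// when `halfOpen` is false.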
3780 template <typename OpType>
3781 static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
3782 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
3783 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
3784 bool halfOpen = true, int64_t min = 1) {
3785 assert(arrayAttr1.size() <= shape.size());
3786 assert(arrayAttr2.size() <= shape.size());
3787 for (auto [index, it] :
3788 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
3789 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
3790 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
3791 int64_t max = std::get<2>(it);
3792 if (!halfOpen)
3793 max += 1;
3794 if (val1 + val2 < 0 || val1 + val2 >= max)
3795 return op.emitOpError("expected sum(")
3796 << attrName1 << ", " << attrName2 << ") dimension " << index
3797 << " to be confined to [" << min << ", " << max << ")";
3798 }
3799 return success();
3800}
3801
3802 static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
3803 MLIRContext *context) {
3804 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
3805 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
3806 });
3807 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
3808}
3809
3810LogicalResult InsertStridedSliceOp::verify() {
3811 auto sourceVectorType = getSourceVectorType();
3812 auto destVectorType = getDestVectorType();
3813 auto offsets = getOffsetsAttr();
3814 auto strides = getStridesAttr();
3815 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
3816 return emitOpError(
3817 "expected offsets of same size as destination vector rank");
3818 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
3819 return emitOpError("expected strides of same size as source vector rank");
3820 if (sourceVectorType.getRank() > destVectorType.getRank())
3821 return emitOpError(
3822 "expected source rank to be no greater than destination rank");
3823
3824 auto sourceShape = sourceVectorType.getShape();
3825 auto destShape = destVectorType.getShape();
3826 SmallVector<int64_t, 4> sourceShapeAsDestShape(
3827 destShape.size() - sourceShape.size(), 0);
3828 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
3829 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
3830 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
3831 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
3832 offName)) ||
3833 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
3834 /*max=*/1, stridesName,
3835 /*halfOpen=*/false)) ||
3836 failed(isSumOfIntegerArrayAttrConfinedToShape(
3837 *this, offsets,
3838 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
3839 offName, "source vector shape",
3840 /*halfOpen=*/false, /*min=*/1)))
3841 return failure();
3842
3843 unsigned rankDiff = destShape.size() - sourceShape.size();
3844 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
3845 if (sourceVectorType.getScalableDims()[idx] !=
3846 destVectorType.getScalableDims()[idx + rankDiff]) {
3847 return emitOpError("mismatching scalable flags (at source vector idx=")
3848 << idx << ")";
3849 }
3850 if (sourceVectorType.getScalableDims()[idx]) {
3851 auto sourceSize = sourceShape[idx];
3852 auto destSize = destShape[idx + rankDiff];
3853 if (sourceSize != destSize) {
3854 return emitOpError("expected size at idx=")
3855 << idx
3856 << (" to match the corresponding base size from the input "
3857 "vector (")
3858 << sourceSize << (" vs ") << destSize << (")");
3859 }
3860 }
3861 }
3862
3863 return success();
3864}
3865
3866namespace {
3867/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
3868class FoldInsertStridedSliceSplat final
3869 : public OpRewritePattern<InsertStridedSliceOp> {
3870public:
3871 using Base::Base;
3872
3873 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3874 PatternRewriter &rewriter) const override {
3875
3876 auto dst = insertStridedSliceOp.getDest();
3877 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3878 if (!splat || getScalarSplatSource(dst) != splat)
3879 return failure();
3880
3881 rewriter.replaceOp(insertStridedSliceOp, dst);
3882 return success();
3883 }
3884};
3885
3886/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
3887/// to dst.
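///
/// Illustrative example (types chosen arbitrarily):
///
///   %e = vector.extract_strided_slice %dst
///        {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]}
///        : vector<4x4xf32> to vector<2x4xf32>
///   %i = vector.insert_strided_slice %e, %dst
///        {offsets = [1, 0], strides = [1, 1]}
///        : vector<2x4xf32> into vector<4x4xf32>
///
/// re-inserts the slice where it was extracted from, so %i is replaced by %dst.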
3888class FoldInsertStridedSliceOfExtract final
3889 : public OpRewritePattern<InsertStridedSliceOp> {
3890public:
3891 using Base::Base;
3892
3893 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
3894 PatternRewriter &rewriter) const override {
3895 auto extractStridedSliceOp =
3896 insertStridedSliceOp.getValueToStore()
3897 .getDefiningOp<vector::ExtractStridedSliceOp>();
3898
3899 if (!extractStridedSliceOp)
3900 return failure();
3901
3902 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
3903 return failure();
3904
3905 // Check if have the same strides and offsets.
3906 if (extractStridedSliceOp.getStrides() !=
3907 insertStridedSliceOp.getStrides() ||
3908 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
3909 return failure();
3910
3911 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3912 return success();
3913 }
3914};
3915
3916// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
3917// ConstantOp.
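// Illustrative example (values chosen arbitrarily): inserting a constant
// dense<[10, 11]> : vector<2xi32> slice at offsets [1] into a constant
// dense<[0, 1, 2, 3]> : vector<4xi32> destination is rewritten to a single
// arith.constant dense<[0, 10, 11, 3]> : vector<4xi32>.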
3918class InsertStridedSliceConstantFolder final
3919 : public OpRewritePattern<InsertStridedSliceOp> {
3920public:
3921 using Base::Base;
3922
3923 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
3924 // unless the source vector constant has a single use.
3925 static constexpr int64_t vectorSizeFoldThreshold = 256;
3926
3927 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
3928 PatternRewriter &rewriter) const override {
3929 // Return if the 'InsertStridedSliceOp' destination is not defined by a
3930 // compatible vector ConstantOp.
3931 TypedValue<VectorType> destVector = op.getDest();
3932 Attribute vectorDestCst;
3933 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3934 return failure();
3935
3936 VectorType destTy = destVector.getType();
3937 if (destTy.isScalable())
3938 return failure();
3939
3940 // Make sure we do not create too many large constants.
3941 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3942 !destVector.hasOneUse())
3943 return failure();
3944
3945 TypedValue<VectorType> sourceValue = op.getValueToStore();
3946 Attribute sourceCst;
3947 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3948 return failure();
3949
3950 // TODO: Support poison.
3951 if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
3952 return failure();
3953
3954 // TODO: Handle non-unit strides when they become available.
3955 if (op.hasNonUnitStrides())
3956 return failure();
3957
3958 VectorType sliceVecTy = sourceValue.getType();
3959 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3960 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
3961 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
3962 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
3963
3964 // Calculate the destination element indices by enumerating all slice
3965 // positions within the destination and linearizing them. The enumeration
3966 // order is lexicographic, which yields a sequence of monotonically
3967 // increasing linearized position indices.
3968 // Because the destination may have higher dimensionality than the slice,
3969 // we keep track of two overlapping sets of positions and offsets.
3970 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
3971 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
3972 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
3973 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
3974 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
3975 MutableArrayRef<int64_t> currSlicePosition(
3976 currDestPosition.begin() + rankDifference, currDestPosition.end());
3977 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
3978 offsets.end());
3979 do {
3980 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
3981 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
3982 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
3983 "Invalid slice element");
3984 newValues[linearizedPosition] = *sliceValuesIt;
3985 ++sliceValuesIt;
3986 } while (succeeded(
3987 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
3988
3989 auto newAttr = DenseElementsAttr::get(destTy, newValues);
3990 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3991 return success();
3992 }
3993};
3994
3995} // namespace
3996
3997void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
3998 RewritePatternSet &results, MLIRContext *context) {
3999 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4000 InsertStridedSliceConstantFolder>(context);
4001}
4002
4003OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4004 if (getSourceVectorType() == getDestVectorType())
4005 return getValueToStore();
4006 return {};
4007}
4008
4009//===----------------------------------------------------------------------===//
4010// OuterProductOp
4011//===----------------------------------------------------------------------===//
4012
4013/// Build an op without mask, use the type of `acc` as the return type.
4014void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4015 Value lhs, Value rhs, Value acc) {
4016 result.addOperands({lhs, rhs, acc});
4017 result.addTypes(acc.getType());
4018}
4019
4020void OuterProductOp::print(OpAsmPrinter &p) {
4021 p << " " << getLhs() << ", " << getRhs();
4022 if (getAcc()) {
4023 p << ", " << getAcc();
4024 p.printOptionalAttrDict((*this)->getAttrs());
4025 }
4026 p << " : " << getLhs().getType() << ", " << getRhs().getType();
4027}
4028
4029ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4030 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4031 Type tLHS, tRHS;
4032 if (parser.parseOperandList(operandsInfo) ||
4033 parser.parseOptionalAttrDict(result.attributes) ||
4034 parser.parseColonType(tLHS) || parser.parseComma() ||
4035 parser.parseType(tRHS))
4036 return failure();
4037 if (operandsInfo.size() < 2)
4038 return parser.emitError(parser.getNameLoc(),
4039 "expected at least 2 operands");
4040 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4041 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4042 if (!vLHS)
4043 return parser.emitError(parser.getNameLoc(),
4044 "expected vector type for operand #1");
4045
4046 VectorType resType;
4047 if (vRHS) {
4048 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4049 vRHS.getScalableDims()[0]};
4050 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4051 vLHS.getElementType(), scalableDimsRes);
4052 } else {
4053 // Scalar RHS operand
4054 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4055 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4056 scalableDimsRes);
4057 }
4058
4059 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4060 result.attributes.append(
4061 OuterProductOp::getKindAttrName(result.name),
4062 CombiningKindAttr::get(result.getContext(),
4063 OuterProductOp::getDefaultKind()));
4064 }
4065
4066 return failure(
4067 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4068 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4069 (operandsInfo.size() > 2 &&
4070 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4071 parser.addTypeToList(resType, result.types));
4072}
4073
4074LogicalResult OuterProductOp::verify() {
4075 Type tRHS = getOperandTypeRHS();
4076 VectorType vLHS = getOperandVectorTypeLHS(),
4077 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4078 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4079
4080 if (vLHS.getRank() != 1)
4081 return emitOpError("expected 1-d vector for operand #1");
4082
4083 if (vRHS) {
4084 // Proper OUTER operation.
4085 if (vRHS.getRank() != 1)
4086 return emitOpError("expected 1-d vector for operand #2");
4087 if (vRES.getRank() != 2)
4088 return emitOpError("expected 2-d vector result");
4089 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4090 return emitOpError("expected #1 operand dim to match result dim #1");
4091 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4092 return emitOpError("expected #2 operand dim to match result dim #2");
4093 if (vLHS.isScalable() && !vRHS.isScalable()) {
4094 // This restriction reflects what's currently supported in terms of
4095 // scalable vectors. However, we could relax this if there's a use case.
4096 return emitOpError(
4097 "expected either both or only #2 operand dim to be scalable");
4098 }
4099 } else {
4100 // An AXPY operation.
4101 if (vRES.getRank() != 1)
4102 return emitOpError("expected 1-d vector result");
4103 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4104 return emitOpError("expected #1 operand dim to match result dim #1");
4105 }
4106
4107 if (vACC && vACC != vRES)
4108 return emitOpError("expected operand #3 of same type as result type");
4109
4110 // Verify supported combining kind.
4111 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4112 return emitOpError("unsupported outerproduct type");
4113
4114 return success();
4115}
4116
4117// MaskableOpInterface methods.
4118
4119/// Returns the mask type expected by this operation. Mostly used for
4120 /// verification purposes. It requires the operation to be vectorized.
4121Type OuterProductOp::getExpectedMaskType() {
4122 auto vecType = this->getResultVectorType();
4123 return VectorType::get(vecType.getShape(),
4124 IntegerType::get(vecType.getContext(), /*width=*/1),
4125 vecType.getScalableDims());
4126}
4127
4128//===----------------------------------------------------------------------===//
4129// ExtractStridedSliceOp
4130//===----------------------------------------------------------------------===//
4131
4132// Inference works as follows:
4133// 1. Add 'sizes' from prefix of dims in 'offsets'.
4134// 2. Add sizes from 'vectorType' for remaining dims.
4135// Scalable flags are inherited from 'vectorType'.
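// Illustrative example (shapes chosen arbitrarily): with a vector<4x8x16xf32>
// source and offsets/sizes/strides of length 2 where sizes = [2, 4], the
// inferred result type is vector<2x4x16xf32>: the first two dims come from
// 'sizes' and the trailing dim is carried over from the source type.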
4136static Type inferStridedSliceOpResultType(VectorType vectorType,
4137 ArrayAttr offsets, ArrayAttr sizes,
4138 ArrayAttr strides) {
4139 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4140 SmallVector<int64_t, 4> shape;
4141 shape.reserve(vectorType.getRank());
4142 unsigned idx = 0;
4143 for (unsigned e = offsets.size(); idx < e; ++idx)
4144 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4145 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4146 shape.push_back(vectorType.getShape()[idx]);
4147
4148 return VectorType::get(shape, vectorType.getElementType(),
4149 vectorType.getScalableDims());
4150}
4151
4152void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4153 Value source, ArrayRef<int64_t> offsets,
4154 ArrayRef<int64_t> sizes,
4155 ArrayRef<int64_t> strides) {
4156 result.addOperands(source);
4157 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4158 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4159 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4160 result.addTypes(
4161 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4162 offsetsAttr, sizesAttr, stridesAttr));
4163 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4164 offsetsAttr);
4165 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4166 sizesAttr);
4167 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4168 stridesAttr);
4169}
4170
4171LogicalResult ExtractStridedSliceOp::verify() {
4172 auto type = getSourceVectorType();
4173 auto offsets = getOffsetsAttr();
4174 auto sizes = getSizesAttr();
4175 auto strides = getStridesAttr();
4176 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4177 return emitOpError(
4178 "expected offsets, sizes and strides attributes of same size");
4179
4180 auto shape = type.getShape();
4181 auto offName = getOffsetsAttrName();
4182 auto sizesName = getSizesAttrName();
4183 auto stridesName = getStridesAttrName();
4184 if (failed(
4185 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4186 failed(
4187 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4188 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4189 stridesName)) ||
4190 failed(
4191 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4192 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4193 /*halfOpen=*/false,
4194 /*min=*/1)) ||
4195 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4196 /*max=*/1, stridesName,
4197 /*halfOpen=*/false)) ||
4198 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4199 shape, offName, sizesName,
4200 /*halfOpen=*/false)))
4201 return failure();
4202
4203 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4204 offsets, sizes, strides);
4205 if (getResult().getType() != resultType)
4206 return emitOpError("expected result type to be ") << resultType;
4207
4208 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4209 if (type.getScalableDims()[idx]) {
4210 auto inputDim = type.getShape()[idx];
4211 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4212 if (inputDim != inputSize)
4213 return emitOpError("expected size at idx=")
4214 << idx
4215 << (" to match the corresponding base size from the input "
4216 "vector (")
4217 << inputSize << (" vs ") << inputDim << (")");
4218 }
4219 }
4220
4221 return success();
4222}
4223
4224 // When the source of an ExtractStridedSlice comes from a chain of
4225 // InsertStridedSlice ops, try to use the source of one of those inserts if we
4226 // can detect that the extracted vector is a subset of an inserted vector.
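// Illustrative example (types chosen arbitrarily):
//
//   %0 = vector.insert_strided_slice %src, %dst
//        {offsets = [2, 0], strides = [1, 1]}
//        : vector<2x4xf32> into vector<8x4xf32>
//   %1 = vector.extract_strided_slice %0
//        {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]}
//        : vector<8x4xf32> to vector<2x4xf32>
//
// The extracted chunk lies entirely within the inserted chunk, so the extract
// is rewritten to read from %src directly (with offsets [0, 0]).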
4227static LogicalResult
4228foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4229 // Helper to extract integer out of ArrayAttr.
4230 auto getElement = [](ArrayAttr array, int idx) {
4231 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4232 };
4233 ArrayAttr extractOffsets = op.getOffsets();
4234 ArrayAttr extractStrides = op.getStrides();
4235 ArrayAttr extractSizes = op.getSizes();
4236 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4237 while (insertOp) {
4238 if (op.getSourceVectorType().getRank() !=
4239 insertOp.getSourceVectorType().getRank())
4240 return failure();
4241 ArrayAttr insertOffsets = insertOp.getOffsets();
4242 ArrayAttr insertStrides = insertOp.getStrides();
4243 // If the rank of extract is greater than the rank of insert, we are likely
4244 // extracting a partial chunk of the vector inserted.
4245 if (extractOffsets.size() > insertOffsets.size())
4246 return failure();
4247 bool partialOverlap = false;
4248 bool disjoint = false;
4249 SmallVector<int64_t, 4> offsetDiffs;
4250 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4251 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4252 return failure();
4253 int64_t start = getElement(insertOffsets, dim);
4254 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4255 int64_t offset = getElement(extractOffsets, dim);
4256 int64_t size = getElement(extractSizes, dim);
4257 // Check if the start of the extract offset is in the interval inserted.
4258 if (start <= offset && offset < end) {
4259 // If the extract interval overlaps but is not fully included we may
4260 // have a partial overlap that will prevent any folding.
4261 if (offset + size > end)
4262 partialOverlap = true;
4263 offsetDiffs.push_back(offset - start);
4264 continue;
4265 }
4266 disjoint = true;
4267 break;
4268 }
4269 // The extract element chunk is a subset of the insert element.
4270 if (!disjoint && !partialOverlap) {
4271 op.setOperand(insertOp.getValueToStore());
4272 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4273 OpBuilder b(op.getContext());
4274 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4275 return success();
4276 }
4277 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4278 // in the insert chain.
4279 if (disjoint)
4280 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4281 else {
4282 // The extracted vector partially overlap the inserted vector, we cannot
4283 // fold.
4284 return failure();
4285 }
4286 }
4287 return failure();
4288}
4289
4290// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4291static OpFoldResult
4292 foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
4293 Attribute foldInput) {
4294
4295 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4296 if (!dense)
4297 return {};
4298
4299 // TODO: Handle non-unit strides when they become available.
4300 if (op.hasNonUnitStrides())
4301 return {};
4302
4303 VectorType sourceVecTy = op.getSourceVectorType();
4304 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4305 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4306
4307 VectorType sliceVecTy = op.getType();
4308 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4309 int64_t rank = sliceVecTy.getRank();
4310
4311 // Expand offsets and sizes to match the vector rank.
4312 SmallVector<int64_t, 4> offsets(rank, 0);
4313 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4314
4315 SmallVector<int64_t, 4> sizes(sourceShape);
4316 copy(getI64SubArray(op.getSizes()), sizes.begin());
4317
4318 // Calculate the slice elements by enumerating all slice positions and
4319 // linearizing them. The enumeration order is lexicographic which yields a
4320 // sequence of monotonically increasing linearized position indices.
4321 const auto denseValuesBegin = dense.value_begin<Attribute>();
4322 SmallVector<Attribute> sliceValues;
4323 sliceValues.reserve(sliceVecTy.getNumElements());
4324 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4325 do {
4326 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4327 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4328 "Invalid index");
4329 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4330 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4331
4332 assert(static_cast<int64_t>(sliceValues.size()) ==
4333 sliceVecTy.getNumElements() &&
4334 "Invalid number of slice elements");
4335 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4336}
4337
4338OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4339 if (getSourceVectorType() == getResult().getType())
4340 return getSource();
4341 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4342 return getResult();
4343
4344 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4345 if (auto splat =
4346 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4347 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4348
4349 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4350 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4351}
4352
4353void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4354 populateFromInt64AttrArray(getOffsets(), results);
4355}
4356
4357namespace {
4358
4359// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4360// CreateMaskOp.
4361//
4362// Example:
4363//
4364// %mask = vector.create_mask %ub : vector<16xi1>
4365// %slice = vector.extract_strided_slice [%offset] [8] [1]
4366//
4367// to
4368//
4369// %new_ub = arith.subi %ub, %offset
4370// %mask = vector.create_mask %new_ub : vector<8xi1>
4371class StridedSliceCreateMaskFolder final
4372 : public OpRewritePattern<ExtractStridedSliceOp> {
4373 using Base::Base;
4374
4375public:
4376 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4377 PatternRewriter &rewriter) const override {
4378 Location loc = extractStridedSliceOp.getLoc();
4379 // Return if 'extractStridedSliceOp' operand is not defined by a
4380 // CreateMaskOp.
4381 auto createMaskOp =
4382 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4383 if (!createMaskOp)
4384 return failure();
4385 // Return if 'extractStridedSliceOp' has non-unit strides.
4386 if (extractStridedSliceOp.hasNonUnitStrides())
4387 return failure();
4388 // Gather the mask's dimension-size operands.
4389 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4390 // Gather strided slice offsets and sizes.
4391 SmallVector<int64_t> sliceOffsets;
4392 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4393 sliceOffsets);
4394 SmallVector<int64_t> sliceSizes;
4395 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4396
4397 // Compute slice of vector mask region.
4398 SmallVector<Value> sliceMaskDimSizes;
4399 sliceMaskDimSizes.reserve(maskDimSizes.size());
4400 // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4401 // only iterate on the leading dim sizes. The tail accounts for the
4402 // remaining dim sizes.
4403 for (auto [maskDimSize, sliceOffset, sliceSize] :
4404 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4405 // No need to clamp on min/max values, because create_mask has clamping
4406 // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4407 // greater than the vector dim size.
4408 IntegerAttr offsetAttr =
4409 rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4410 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4411 Value sliceMaskDimSize =
4412 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4413 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4414 }
4415 // Add unchanged dimensions.
4416 llvm::append_range(
4417 sliceMaskDimSizes,
4418 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4419 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4420 // region.
4421 rewriter.replaceOpWithNewOp<CreateMaskOp>(
4422 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4423 sliceMaskDimSizes);
4424 return success();
4425 }
4426};
4427
4428// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4429// ConstantMaskOp.
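//
// For example (an illustrative sketch, values chosen for clarity; not taken
// from the upstream tests):
//
//   %mask = vector.constant_mask [8] : vector<16xi1>
//   %slice = vector.extract_strided_slice %mask
//       {offsets = [4], sizes = [8], strides = [1]} : vector<16xi1> to vector<8xi1>
//
// folds to
//
//   %slice = vector.constant_mask [4] : vector<8xi1>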
4430class StridedSliceConstantMaskFolder final
4431 : public OpRewritePattern<ExtractStridedSliceOp> {
4432public:
4433 using Base::Base;
4434
4435 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4436 PatternRewriter &rewriter) const override {
4437 // Return if 'extractStridedSliceOp' operand is not defined by a
4438 // ConstantMaskOp.
4439 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4440 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4441 if (!constantMaskOp)
4442 return failure();
4443 // Return if 'extractStridedSliceOp' has non-unit strides.
4444 if (extractStridedSliceOp.hasNonUnitStrides())
4445 return failure();
4446 // Gather constant mask dimension sizes.
4447 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4448 // Gather strided slice offsets and sizes.
4449 SmallVector<int64_t> sliceOffsets;
4450 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4451 sliceOffsets);
4452 SmallVector<int64_t> sliceSizes;
4453 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4454
4455 // Compute slice of vector mask region.
4456 SmallVector<int64_t> sliceMaskDimSizes;
4457 sliceMaskDimSizes.reserve(maskDimSizes.size());
4458 for (auto [maskDimSize, sliceOffset, sliceSize] :
4459 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4460 int64_t sliceMaskDimSize = std::max(
4461 static_cast<int64_t>(0),
4462 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4463 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4464 }
4465 // Add unchanged dimensions.
4466 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4467 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4468 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4469 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4470 // region is a conjunction of mask dim intervals).
4471 if (llvm::is_contained(sliceMaskDimSizes, 0))
4472 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4473
4474 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4475 // region.
4476 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4477 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4478 sliceMaskDimSizes);
4479 return success();
4480 }
4481};
4482
4483// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4484 // BroadcastOp(ExtractStridedSliceOp).
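//
// For example (an illustrative sketch; the SSA names and shapes are made up):
//
//   %b = vector.broadcast %src : vector<4xf32> to vector<2x4xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [0, 1], sizes = [1, 2], strides = [1, 1]}
//       : vector<2x4xf32> to vector<1x2xf32>
//
// becomes
//
//   %e = vector.extract_strided_slice %src
//       {offsets = [1], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
//   %s = vector.broadcast %e : vector<2xf32> to vector<1x2xf32>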
4485class StridedSliceBroadcast final
4486 : public OpRewritePattern<ExtractStridedSliceOp> {
4487public:
4488 using Base::Base;
4489
4490 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4491 PatternRewriter &rewriter) const override {
4492 auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4493 if (!broadcast)
4494 return failure();
4495 auto srcVecType =
4496 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4497 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4498 auto dstVecType = llvm::cast<VectorType>(op.getType());
4499 unsigned dstRank = dstVecType.getRank();
4500 unsigned rankDiff = dstRank - srcRank;
4501 // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4502 // (n -> m with n > m). If they are originally both broadcasted *and*
4503 // sliced, this can be simplified to just broadcasting.
4504 bool needsSlice = false;
4505 for (unsigned i = 0; i < srcRank; i++) {
4506 if (srcVecType.getDimSize(i) != 1 &&
4507 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4508 needsSlice = true;
4509 break;
4510 }
4511 }
4512 Value source = broadcast.getSource();
4513 if (needsSlice) {
4514 SmallVector<int64_t> offsets =
4515 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4516 SmallVector<int64_t> sizes =
4517 getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4518 for (unsigned i = 0; i < srcRank; i++) {
4519 if (srcVecType.getDimSize(i) == 1) {
4520 // In case this dimension was broadcasted *and* sliced, the offset
4521 // and size need to be updated now that there is no broadcast before
4522 // the slice.
4523 offsets[i] = 0;
4524 sizes[i] = 1;
4525 }
4526 }
4527 source = ExtractStridedSliceOp::create(
4528 rewriter, op->getLoc(), source, offsets, sizes,
4529 getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4530 }
4531 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4532 return success();
4533 }
4534};
4535
4536/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
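///
/// For example (an illustrative sketch):
///
///   %splat = vector.broadcast %s : f32 to vector<8xf32>
///   %slice = vector.extract_strided_slice %splat
///       {offsets = [2], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32>
///
/// becomes
///
///   %slice = vector.broadcast %s : f32 to vector<4xf32>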
4537class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4538public:
4539 using Base::Base;
4540
4541 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4542 PatternRewriter &rewriter) const override {
4543
4544 Value splat = getScalarSplatSource(op.getSource());
4545 if (!splat)
4546 return failure();
4547 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4548 return success();
4549 }
4550};
4551
4552/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4553/// slice is contiguous, into extract and shape_cast.
4554///
4555/// Example:
4556/// Before:
4557/// %1 = vector.extract_strided_slice %arg0 {
4558/// offsets = [0, 0, 0, 0, 0],
4559/// sizes = [1, 1, 1, 1, 8],
4560/// strides = [1, 1, 1, 1, 1]
4561/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4562/// After:
4563/// %0 = vector.extract %arg0[0, 0, 0, 0]
4564/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4565/// %1 = vector.shape_cast %0
4566/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4567///
4568class ContiguousExtractStridedSliceToExtract final
4569 : public OpRewritePattern<ExtractStridedSliceOp> {
4570public:
4571 using Base::Base;
4572
4573 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4574 PatternRewriter &rewriter) const override {
4575 if (op.hasNonUnitStrides())
4576 return failure();
4577 Value source = op.getOperand();
4578 auto sourceType = cast<VectorType>(source.getType());
4579 if (sourceType.isScalable() || sourceType.getRank() == 0)
4580 return failure();
4581
4582 // Compute the number of offsets to pass to ExtractOp::build. That is the
4583 // difference between the source rank and the desired slice rank. We walk
4584 // the dimensions from innermost out, and stop when the next slice dimension
4585 // is not full-size.
4586 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4587 int numOffsets;
4588 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4589 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4590 break;
4591 }
4592
4593 // If the created extract op would have no offsets, then this whole
4594 // extract_strided_slice is the identity and should have been handled by
4595 // other canonicalizations.
4596 if (numOffsets == 0)
4597 return failure();
4598
4599 // If not even the inner-most dimension is full-size, this op can't be
4600 // rewritten as an ExtractOp.
4601 if (numOffsets == sourceType.getRank() &&
4602 static_cast<int>(sizes.size()) == sourceType.getRank())
4603 return failure();
4604
4605 // The outer dimensions must have unit size.
4606 for (int i = 0; i < numOffsets; ++i) {
4607 if (sizes[i] != 1)
4608 return failure();
4609 }
4610
4611 // Avoid generating slices that have leading unit dimensions. The shape_cast
4612 // op that we create below would otherwise be lowered through inefficient
4613 // generic fallback patterns (ShapeCastOpRewritePattern).
4614 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
4615 sizes[numOffsets] == 1) {
4616 ++numOffsets;
4617 }
4618
4619 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
4620 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
4621 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
4622 extractOffsets);
4623 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
4624 return success();
4625 }
4626};
4627
4628} // namespace
4629
4630void ExtractStridedSliceOp::getCanonicalizationPatterns(
4631 RewritePatternSet &results, MLIRContext *context) {
4632 // Patterns to fold ExtractStridedSliceOp of CreateMaskOp / ConstantMaskOp /
4633 // BroadcastOp / splat sources, and to rewrite contiguous slices into extract + shape_cast.
4634 results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
4635 StridedSliceBroadcast, StridedSliceSplat,
4636 ContiguousExtractStridedSliceToExtract>(context);
4637}
4638
4639//===----------------------------------------------------------------------===//
4640// TransferReadOp
4641//===----------------------------------------------------------------------===//
4642
4643/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
4644void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4645 VectorType vectorType, Value source,
4646 ValueRange indices, std::optional<Value> padding,
4647 AffineMapAttr permutationMapAttr,
4648 /*optional*/ ArrayAttr inBoundsAttr) {
4649
4650 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4651 if (!padding)
4652 padding = ub::PoisonOp::create(builder, result.location, elemType);
4653 build(builder, result, vectorType, source, indices, permutationMapAttr,
4654 *padding, /*mask=*/Value(), inBoundsAttr);
4655}
4656
4657/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
4658void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4659 VectorType vectorType, Value source,
4660 ValueRange indices, std::optional<Value> padding,
4661 AffineMap permutationMap,
4662 std::optional<ArrayRef<bool>> inBounds) {
4663 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4664 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4665 ? builder.getBoolArrayAttr(inBounds.value())
4666 : builder.getBoolArrayAttr(
4667 SmallVector<bool>(vectorType.getRank(), false));
4668 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4669 if (!padding)
4670 padding = ub::PoisonOp::create(builder, result.location, elemType);
4671 build(builder, result, vectorType, source, indices, *padding,
4672 permutationMapAttr, inBoundsAttr);
4673}
4674
4675/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
4676void TransferReadOp::build(OpBuilder &builder, OperationState &result,
4677 VectorType vectorType, Value source,
4678 ValueRange indices, std::optional<Value> padding,
4679 std::optional<ArrayRef<bool>> inBounds) {
4680 AffineMap permutationMap = getTransferMinorIdentityMap(
4681 llvm::cast<ShapedType>(source.getType()), vectorType);
4682 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
4683 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
4684 ? builder.getBoolArrayAttr(inBounds.value())
4685 : builder.getBoolArrayAttr(
4686 SmallVector<bool>(vectorType.getRank(), false));
4687 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4688 if (!padding)
4689 padding = ub::PoisonOp::create(builder, result.location, elemType);
4690 build(builder, result, vectorType, source, indices, permutationMapAttr,
4691 *padding,
4692 /*mask=*/Value(), inBoundsAttr);
4693}
4694
4695template <typename EmitFun>
4696static LogicalResult verifyPermutationMap(AffineMap permutationMap,
4697 EmitFun emitOpError) {
4698 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
4699 for (auto expr : permutationMap.getResults()) {
4700 auto dim = dyn_cast<AffineDimExpr>(expr);
4701 auto zero = dyn_cast<AffineConstantExpr>(expr);
4702 if (zero) {
4703 if (zero.getValue() != 0) {
4704 return emitOpError(
4705 "requires a projected permutation_map (at most one dim or the zero "
4706 "constant can appear in each result)");
4707 }
4708 continue;
4709 }
4710 if (!dim) {
4711 return emitOpError("requires a projected permutation_map (at most one "
4712 "dim or the zero constant can appear in each result)");
4713 }
4714 if (seen[dim.getPosition()]) {
4715 return emitOpError(
4716 "requires a permutation_map that is a permutation (found one dim "
4717 "used more than once)");
4718 }
4719 seen[dim.getPosition()] = true;
4720 }
4721 return success();
4722}
4723
4724static LogicalResult
4725verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
4726 VectorType vectorType, VectorType maskType,
4727 VectorType inferredMaskType, AffineMap permutationMap,
4728 ArrayAttr inBounds) {
4729 if (op->hasAttr("masked")) {
4730 return op->emitOpError("masked attribute has been removed. "
4731 "Use in_bounds instead.");
4732 }
4733
4734 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
4735 return op->emitOpError(
4736 "requires source to be a memref or ranked tensor type");
4737
4738 auto elementType = shapedType.getElementType();
4739 DataLayout dataLayout = DataLayout::closest(op);
4740 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
4741 // Memref or tensor has vector element type.
4742 unsigned sourceVecSize =
4743 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
4744 vectorElementType.getShape().back();
4745 unsigned resultVecSize =
4746 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
4747 vectorType.getShape().back();
4748 if (resultVecSize % sourceVecSize != 0)
4749 return op->emitOpError(
4750 "requires the bitwidth of the minor 1-D vector to be an integral "
4751 "multiple of the bitwidth of the minor 1-D vector of the source");
4752
4753 unsigned sourceVecEltRank = vectorElementType.getRank();
4754 unsigned resultVecRank = vectorType.getRank();
4755 if (sourceVecEltRank > resultVecRank)
4756 return op->emitOpError(
4757 "requires source vector element and vector result ranks to match.");
4758 unsigned rankOffset = resultVecRank - sourceVecEltRank;
4759 // Check that permutation map results match 'rankOffset' of vector type.
4760 if (permutationMap.getNumResults() != rankOffset)
4761 return op->emitOpError("requires a permutation_map with result dims of "
4762 "the same rank as the vector type");
4763
4764 if (maskType)
4765 return op->emitOpError("does not support masks with vector element type");
4766 } else {
4767 // Memref or tensor has scalar element type.
4768 unsigned minorSize =
4769 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
4770 unsigned resultVecSize =
4771 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
4772 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
4773 return op->emitOpError(
4774 "requires the bitwidth of the minor 1-D vector to be an integral "
4775 "multiple of the bitwidth of the source element type");
4776
4777 // Check that permutation map results match rank of vector type.
4778 if (permutationMap.getNumResults() != vectorType.getRank())
4779 return op->emitOpError("requires a permutation_map with result dims of "
4780 "the same rank as the vector type");
4781 }
4782
4783 if (permutationMap.getNumSymbols() != 0)
4784 return op->emitOpError("requires permutation_map without symbols");
4785
4786 if (permutationMap.getNumInputs() != shapedType.getRank())
4787 return op->emitOpError("requires a permutation_map with input dims of the "
4788 "same rank as the source type");
4789
4790 if (maskType && maskType != inferredMaskType)
4791 return op->emitOpError("inferred mask type (")
4792 << inferredMaskType << ") and mask operand type (" << maskType
4793 << ") don't match";
4794
4795 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
4796 return op->emitOpError("expects the in_bounds attr of same rank "
4797 "as permutation_map results: ")
4798 << AffineMapAttr::get(permutationMap)
4799 << " vs inBounds of size: " << inBounds.size();
4800
4801 return success();
4802}
4803
4804static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
4805 SmallVector<StringRef, 3> elidedAttrs;
4806 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
4807 if (op.getPermutationMap().isMinorIdentity())
4808 elidedAttrs.push_back(op.getPermutationMapAttrName());
4809 // Elide in_bounds attribute if all dims are out-of-bounds.
4810 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
4811 elidedAttrs.push_back(op.getInBoundsAttrName());
4812 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
4813}
4814
4815void TransferReadOp::print(OpAsmPrinter &p) {
4816 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
4817 if (getMask())
4818 p << ", " << getMask();
4819 printTransferAttrs(p, *this);
4820 p << " : " << getShapedType() << ", " << getVectorType();
4821}
4822
4823VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
4824 AffineMap permMap) {
4825 auto i1Type = IntegerType::get(permMap.getContext(), 1);
4826 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
4827 assert(invPermMap && "Inversed permutation map couldn't be computed");
4828 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
4829
4830 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
4831 // 0-D mask into a single-element 1-D mask.
4832 if (maskShape.empty())
4833 maskShape.push_back(1);
4834
4835 SmallVector<bool> scalableDims =
4836 applyPermutationMap(invPermMap, vecType.getScalableDims());
4837
4838 return VectorType::get(maskShape, i1Type, scalableDims);
4839}
4840
4841ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
4842 auto &builder = parser.getBuilder();
4843 SMLoc typesLoc;
4844 OpAsmParser::UnresolvedOperand sourceInfo;
4845 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
4846 OpAsmParser::UnresolvedOperand paddingInfo;
4847 SmallVector<Type, 2> types;
4848 OpAsmParser::UnresolvedOperand maskInfo;
4849 // Parsing with support for paddingValue.
4850 if (parser.parseOperand(sourceInfo) ||
4851 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
4852 parser.parseComma() || parser.parseOperand(paddingInfo))
4853 return failure();
4854 ParseResult hasMask = parser.parseOptionalComma();
4855 if (hasMask.succeeded()) {
4856 if (parser.parseOperand(maskInfo))
4857 return failure();
4858 }
4859 if (parser.parseOptionalAttrDict(result.attributes) ||
4860 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
4861 return failure();
4862 if (types.size() != 2)
4863 return parser.emitError(typesLoc, "requires two types");
4864 auto indexType = builder.getIndexType();
4865 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
4866 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
4867 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
4868 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
4869 if (!vectorType)
4870 return parser.emitError(typesLoc, "requires vector type");
4871 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
4872 Attribute permMapAttr = result.attributes.get(permMapAttrName);
4873 AffineMap permMap;
4874 if (!permMapAttr) {
4875 if (shapedType.getRank() <
4876 getEffectiveVectorRankForXferOp(shapedType, vectorType))
4877 return parser.emitError(typesLoc,
4878 "expected a custom permutation_map when "
4879 "rank(source) != rank(destination)");
4880 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
4881 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
4882 } else {
4883 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
4884 }
4885 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
4886 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
4887 if (!inBoundsAttr) {
4888 result.addAttribute(inBoundsAttrName,
4889 builder.getBoolArrayAttr(
4890 SmallVector<bool>(permMap.getNumResults(), false)));
4891 }
4892 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
4893 parser.resolveOperands(indexInfo, indexType, result.operands) ||
4894 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
4895 result.operands))
4896 return failure();
4897 if (hasMask.succeeded()) {
4898 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
4899 return parser.emitError(
4900 maskInfo.location, "does not support masks with vector element type");
4901 if (vectorType.getRank() != permMap.getNumResults()) {
4902 return parser.emitError(typesLoc,
4903 "expected the same rank for the vector and the "
4904 "results of the permutation map");
4905 }
4906 // Instead of adding the mask type as an op type, compute it based on the
4907 // vector type and the permutation map (to keep the type signature small).
4908 auto maskType = inferTransferOpMaskType(vectorType, permMap);
4909 if (parser.resolveOperand(maskInfo, maskType, result.operands))
4910 return failure();
4911 }
4912 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
4913 builder.getDenseI32ArrayAttr(
4914 {1, static_cast<int32_t>(indexInfo.size()), 1,
4915 static_cast<int32_t>(hasMask.succeeded())}));
4916 return parser.addTypeToList(vectorType, result.types);
4917}
4918
4919LogicalResult TransferReadOp::verify() {
4920 // Consistency of elemental types in source and vector.
4921 ShapedType shapedType = getShapedType();
4922 VectorType vectorType = getVectorType();
4923 VectorType maskType = getMaskType();
4924 auto paddingType = getPadding().getType();
4925 auto permutationMap = getPermutationMap();
4926 VectorType inferredMaskType =
4927 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
4928 : VectorType();
4929 auto sourceElementType = shapedType.getElementType();
4930
4931 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
4932 return emitOpError("requires ") << shapedType.getRank() << " indices";
4933
4934 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
4935 shapedType, vectorType, maskType,
4936 inferredMaskType, permutationMap, getInBounds())))
4937 return failure();
4938
4939 if (auto sourceVectorElementType =
4940 llvm::dyn_cast<VectorType>(sourceElementType)) {
4941 // Source has vector element type.
4942 // Check that 'sourceVectorElementType' and 'paddingType' types match.
4943 if (sourceVectorElementType != paddingType)
4944 return emitOpError(
4945 "requires source element type and padding type to match.");
4946
4947 } else {
4948 // Check that 'paddingType' is valid to store in a vector type.
4949 if (!VectorType::isValidElementType(paddingType))
4950 return emitOpError("requires valid padding vector elemental type");
4951
4952 // Check that padding type and vector element types match.
4953 if (paddingType != sourceElementType)
4954 return emitOpError(
4955 "requires formal padding and source of the same elemental type");
4956 }
4957
4958 return verifyPermutationMap(permutationMap,
4959 [&](Twine t) { return emitOpError(t); });
4960}
4961
4962// MaskableOpInterface methods.
4963
4964/// Returns the mask type expected by this operation. Mostly used for
4965 /// verification purposes. It requires the operation to be vectorized.
4966Type TransferReadOp::getExpectedMaskType() {
4967 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
4968}
4969
4970//===----------------------------------------------------------------------===//
4971// TransferReadOp: VectorTransferOpInterface methods.
4972//===----------------------------------------------------------------------===//
4973VectorType TransferReadOp::getVectorType() {
4974 return cast<VectorType>(getVector().getType());
4975}
4976
4977template <typename TransferOp>
4978static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
4979 // TODO: support more aggressive createOrFold on:
4980 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
4981 if (op.getShapedType().isDynamicDim(indicesIdx))
4982 return false;
4983 Value index = op.getIndices()[indicesIdx];
4984 std::optional<int64_t> cstOp = getConstantIntValue(index);
4985 if (!cstOp.has_value())
4986 return false;
4987
4988 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
4989 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
4990
4991 return cstOp.value() + vectorSize <= sourceSize;
4992}
4993
4994template <typename TransferOp>
4995static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
4996 // TODO: support 0-d corner case.
4997 // TODO: Be less conservative.
4998 if (op.getTransferRank() == 0)
4999 return failure();
5000 AffineMap permutationMap = op.getPermutationMap();
5001 bool changed = false;
5002 SmallVector<bool, 4> newInBounds;
5003 newInBounds.reserve(op.getTransferRank());
5004 // Indices of non-broadcast dims - used when analyzing broadcast dims.
5005 SmallVector<unsigned> nonBcastDims;
5006
5007 // 1. Process non-broadcast dims
5008 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5009 // 1.1. Already marked as in-bounds, nothing to see here.
5010 if (op.isDimInBounds(i)) {
5011 newInBounds.push_back(true);
5012 continue;
5013 }
5014 // 1.2. Currently out-of-bounds, check whether we can statically determine
5015 // it is inBounds.
5016 bool inBounds = false;
5017 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5018 if (dimExpr) {
5019 inBounds = isInBounds(op, /*resultIdx=*/i,
5020 /*indicesIdx=*/dimExpr.getPosition());
5021 nonBcastDims.push_back(i);
5022 }
5023
5024 newInBounds.push_back(inBounds);
5025 // We commit the pattern if it is "more inbounds".
5026 changed |= inBounds;
5027 }
5028
5029 // 2. Handle broadcast dims
5030 // If all non-broadcast dims are "in bounds", then all bcast dims should be
5031 // "in bounds" as well.
5032 bool allNonBcastDimsInBounds = llvm::all_of(
5033 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5034 if (allNonBcastDimsInBounds) {
5035 for (size_t idx : permutationMap.getBroadcastDims()) {
5036 changed |= !newInBounds[idx];
5037 newInBounds[idx] = true;
5038 }
5039 }
5040
5041 if (!changed)
5042 return failure();
5043 // OpBuilder is only used as a helper to build a BoolArrayAttr.
5044 OpBuilder b(op.getContext());
5045 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5046 return success();
5047}
5048
5049template <typename TransferOp>
5050static LogicalResult foldTransferFullMask(TransferOp op) {
5051 auto mask = op.getMask();
5052 if (!mask)
5053 return failure();
5054
5055 if (getMaskFormat(mask) != MaskFormat::AllTrue)
5056 return failure();
5057
5058 op.getMaskMutable().clear();
5059 return success();
5060}
5061
5062/// ```
5063/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5064/// : vector<1x4xf32>, tensor<4x4xf32>
5065/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5066/// : tensor<4x4xf32>, vector<1x4xf32>
5067/// ```
5068/// -> Folds into
5069/// ```
5070/// %v0
5071/// ```
5072static Value foldRAW(TransferReadOp readOp) {
5073 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5074 return {};
5075 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5076 while (defWrite) {
5077 if (checkSameValueRAW(defWrite, readOp))
5078 return defWrite.getVector();
5079 if (!isDisjointTransferIndices(
5080 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5081 cast<VectorTransferOpInterface>(readOp.getOperation())))
5082 break;
5083 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5084 }
5085 return {};
5086}
5087
5088OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5089 if (Value vec = foldRAW(*this))
5090 return vec;
5091 /// transfer_read(memrefcast) -> transfer_read
5092 if (succeeded(foldTransferInBoundsAttribute(*this)))
5093 return getResult();
5094 if (succeeded(foldTransferFullMask(*this)))
5095 return getResult();
5096 if (succeeded(memref::foldMemRefCast(*this)))
5097 return getResult();
5098 if (succeeded(tensor::foldTensorCast(*this)))
5099 return getResult();
5100 return OpFoldResult();
5101}
5102
5103std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5104 return llvm::to_vector<4>(getVectorType().getShape());
5105}
5106
5107void TransferReadOp::getEffects(
5108 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5109 &effects) {
5110 if (llvm::isa<MemRefType>(getShapedType()))
5111 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5112 SideEffects::DefaultResource::get());
5113}
5114
5115Speculation::Speculatability TransferReadOp::getSpeculatability() {
5116 if (hasPureTensorSemantics())
5119}
5120
5121namespace {
5122 /// Store to load forwarding for transfer operations with permutation maps.
5123/// Even if the permutation maps are different we can still propagate the store
5124/// into the load if the size of the dimensions read and written match. Then we
5125/// can replace the transfer_read + transfer_write by vector.broadcast and
5126/// vector.transpose.
5127/// Example:
5128/// ```
5129/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5130/// {in_bounds = [true, true],
5131/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5132/// vector<4x1xf32>, tensor<4x4x4xf32>
5133/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5134/// {in_bounds = [true, true, true, true],
5135/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5136/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5137/// ```
5138/// To:
5139/// ```
5140 /// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
5141/// %r = vector.transpose %0, [3, 0, 2, 1] :
5142/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5143/// ```
5144struct TransferReadAfterWriteToBroadcast
5145 : public OpRewritePattern<TransferReadOp> {
5146 using Base::Base;
5147
5148 LogicalResult matchAndRewrite(TransferReadOp readOp,
5149 PatternRewriter &rewriter) const override {
5150 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5151 if (!defWrite)
5152 return failure();
5153 // Bail if we need an alias analysis.
5154 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5155 return failure();
5156 // Bail if we need a bounds analysis.
5157 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5158 return failure();
5159 // TODO: If the written transfer chunk is a superset of the read transfer
5160 // chunk we could do an extract_strided_slice.
5161 if (readOp.getTransferChunkAccessed() !=
5162 defWrite.getTransferChunkAccessed())
5163 return failure();
5164 // TODO: Support cases where a dim is explicitly written but implicitly
5165 // read (i.e., a unit dim that is rank reduced).
5166 if (getUnusedDimsBitVector({readOp.getPermutationMap()}) !=
5167 getUnusedDimsBitVector({defWrite.getPermutationMap()}))
5168 return failure();
5169 // This pattern should only catch the broadcast case; the non-broadcast case
5170 // should be handled by a separate pattern to keep the application conditions
5171 // clean and separate.
5172 AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
5173 AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
5174 bool bcast = !readMap.getBroadcastDims().empty() ||
5175 !writeMap.getBroadcastDims().empty();
5176 if (!bcast)
5177 return failure();
5178 // At this point, we know we have a bcast.
5179 // Bail in the masked case (too complex atm and needed to properly account
5180 // for padding).
5181 if (readOp.getMask() || defWrite.getMask())
5182 return failure();
5183 // If the indices are not the same, a shift may be required; bail.
5184 if (readOp.getIndices() != defWrite.getIndices())
5185 return failure();
5186
5187 Value vec = defWrite.getVector();
5188 // TODO: loop through the chain of transfer_write if we can prove that they
5189 // don't overlap with the transfer_read. This requires improving
5190 // `isDisjointTransferIndices` helper.
5191 AffineMap map = readMap.compose(writeMap);
5192 if (map.getNumResults() == 0)
5193 return failure();
5194 // Calculate the permutation to apply to go from the vector stored to the
5195 // vector read.
5196 SmallVector<unsigned> permutation;
5197 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
5198 return failure();
5199
5200 Location loc = readOp.getLoc();
5201 // Calculate the broadcast shape by applying the reverse permutation to the
5202 // final shape we want.
5203 ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
5204 SmallVector<int64_t> broadcastShape(destShape.size());
5205 SmallVector<bool> broadcastScalableFlags(destShape.size());
5206 for (const auto &pos : llvm::enumerate(permutation)) {
5207 broadcastShape[pos.value()] = destShape[pos.index()];
5208 broadcastScalableFlags[pos.value()] =
5209 readOp.getVectorType().getScalableDims()[pos.index()];
5210 }
5211 VectorType broadcastedType = VectorType::get(
5212 broadcastShape, defWrite.getVectorType().getElementType(),
5213 broadcastScalableFlags);
5214 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedType, vec);
5215 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
5216 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
5217 transposePerm);
5218 return success();
5219 }
5220};
5221} // namespace
5222
5223void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5224 MLIRContext *context) {
5225 results.add<TransferReadAfterWriteToBroadcast>(context);
5226}
5227
5228FailureOr<std::optional<SmallVector<Value>>>
5229TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5230 if (!hasPureBufferSemantics())
5231 return failure();
5233 getResult());
5234}
5235
5236//===----------------------------------------------------------------------===//
5237// TransferWriteOp
5238//===----------------------------------------------------------------------===//
5239
5240/// 1. Builder with type inference.
5241void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5242 Value vector, Value dest, ValueRange indices,
5243 AffineMapAttr permutationMapAttr,
5244 /*optional*/ Value mask,
5245 /*optional*/ ArrayAttr inBoundsAttr) {
5246 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5247 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5248 mask, inBoundsAttr);
5249}
5250
5251/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5252void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5253 Value vector, Value dest, ValueRange indices,
5254 AffineMapAttr permutationMapAttr,
5255 /*optional*/ ArrayAttr inBoundsAttr) {
5256 build(builder, result, vector, dest, indices, permutationMapAttr,
5257 /*mask=*/Value(), inBoundsAttr);
5258}
5259
5260/// 3. Builder with type inference that sets an empty mask (variant without
5261/// attrs)
5262void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5263 Value vector, Value dest, ValueRange indices,
5264 AffineMap permutationMap,
5265 std::optional<ArrayRef<bool>> inBounds) {
5266 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5267 auto inBoundsAttr =
5268 (inBounds && !inBounds.value().empty())
5269 ? builder.getBoolArrayAttr(inBounds.value())
5270 : builder.getBoolArrayAttr(SmallVector<bool>(
5271 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5272 build(builder, result, vector, dest, indices, permutationMapAttr,
5273 /*mask=*/Value(), inBoundsAttr);
5274}
5275
5276/// 4. Builder with type inference that sets an empty mask and sets permutation
5277/// map to 'getMinorIdentityMap'.
5278void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5279 Value vector, Value dest, ValueRange indices,
5280 std::optional<ArrayRef<bool>> inBounds) {
5281 auto vectorType = llvm::cast<VectorType>(vector.getType());
5282 AffineMap permutationMap = getTransferMinorIdentityMap(
5283 llvm::cast<ShapedType>(dest.getType()), vectorType);
5284 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5285}
5286
5287ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5288 OperationState &result) {
5289 auto &builder = parser.getBuilder();
5290 SMLoc typesLoc;
5291 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5292 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5293 SmallVector<Type, 2> types;
5294 OpAsmParser::UnresolvedOperand maskInfo;
5295 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5296 parser.parseOperand(sourceInfo) ||
5297 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5298 return failure();
5299 ParseResult hasMask = parser.parseOptionalComma();
5300 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5301 return failure();
5302 if (parser.parseOptionalAttrDict(result.attributes) ||
5303 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5304 return failure();
5305 if (types.size() != 2)
5306 return parser.emitError(typesLoc, "requires two types");
5307 auto indexType = builder.getIndexType();
5308 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5309 if (!vectorType)
5310 return parser.emitError(typesLoc, "requires vector type");
5311 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5312 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5313 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5314 auto permMapAttrName =
5315 TransferWriteOp::getPermutationMapAttrName(result.name);
5316 auto permMapAttr = result.attributes.get(permMapAttrName);
5317 AffineMap permMap;
5318 if (!permMapAttr) {
5319 if (shapedType.getRank() <
5320 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5321 return parser.emitError(typesLoc,
5322 "expected a custom permutation_map when "
5323 "rank(source) != rank(destination)");
5324 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5325 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5326 } else {
5327 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5328 }
5329 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5330 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5331 if (!inBoundsAttr) {
5332 result.addAttribute(inBoundsAttrName,
5333 builder.getBoolArrayAttr(
5334 SmallVector<bool>(permMap.getNumResults(), false)));
5335 }
5336 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5337 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5338 parser.resolveOperands(indexInfo, indexType, result.operands))
5339 return failure();
5340 if (hasMask.succeeded()) {
5341 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5342 return parser.emitError(
5343 maskInfo.location, "does not support masks with vector element type");
5344 if (vectorType.getRank() != permMap.getNumResults()) {
5345 return parser.emitError(typesLoc,
5346 "expected the same rank for the vector and the "
5347 "results of the permutation map");
5348 }
5349 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5350 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5351 return failure();
5352 }
5353 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5354 builder.getDenseI32ArrayAttr(
5355 {1, 1, static_cast<int32_t>(indexInfo.size()),
5356 static_cast<int32_t>(hasMask.succeeded())}));
5357 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5358 parser.addTypeToList(shapedType, result.types));
5359}
5360
5361void TransferWriteOp::print(OpAsmPrinter &p) {
5362 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5363 if (getMask())
5364 p << ", " << getMask();
5365 printTransferAttrs(p, *this);
5366 p << " : " << getVectorType() << ", " << getShapedType();
5367}
5368
5369LogicalResult TransferWriteOp::verify() {
5370 // Consistency of elemental types in shape and vector.
5371 ShapedType shapedType = getShapedType();
5372 VectorType vectorType = getVectorType();
5373 VectorType maskType = getMaskType();
5374 auto permutationMap = getPermutationMap();
5375 VectorType inferredMaskType =
5376 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5377 : VectorType();
5378
5379 if (llvm::size(getIndices()) != shapedType.getRank())
5380 return emitOpError("requires ") << shapedType.getRank() << " indices";
5381
5382 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5383 // as the semantics is unclear. This can be revisited later if necessary.
5384 if (hasBroadcastDim())
5385 return emitOpError("should not have broadcast dimensions");
5386
5387 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5388 shapedType, vectorType, maskType,
5389 inferredMaskType, permutationMap, getInBounds())))
5390 return failure();
5391
5392 return verifyPermutationMap(permutationMap,
5393 [&](Twine t) { return emitOpError(t); });
5394}
5395
5396//===----------------------------------------------------------------------===//
5397// TransferWriteOp: MaskableOpInterface methods.
5398//===----------------------------------------------------------------------===//
5399
5400/// Returns the mask type expected by this operation. Mostly used for
5401/// verification purposes.
5402Type TransferWriteOp::getExpectedMaskType() {
5403 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5404}
5405
5406//===----------------------------------------------------------------------===//
5407// TransferWriteOp: VectorTransferOpInterface methods.
5408//===----------------------------------------------------------------------===//
5409Value TransferWriteOp::getVector() { return getOperand(0); }
5410VectorType TransferWriteOp::getVectorType() {
5411 return cast<VectorType>(getValueToStore().getType());
5412}
5413
5414//===----------------------------------------------------------------------===//
5415// TransferWriteOp: fold methods.
5416//===----------------------------------------------------------------------===//
5417/// Fold:
5418/// ```
5419/// %t1 = ...
5420/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5421/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5422/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5423/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5424/// ```
5425///
5426/// into:
5427///
5428/// ```
5429/// %t0
5430/// ```
5431///
5432/// The producer of t1 may or may not be DCE'd depending on whether it is a
5433/// block argument or has side effects.
5434static LogicalResult foldReadInitWrite(TransferWriteOp write,
5435 ArrayRef<Attribute>,
5436 SmallVectorImpl<OpFoldResult> &results) {
5437 // TODO: support 0-d corner case.
5438 if (write.getTransferRank() == 0)
5439 return failure();
5440 auto rankedTensorType =
5441 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5442 // If not operating on tensors, bail.
5443 if (!rankedTensorType)
5444 return failure();
5445 // If no read, bail.
5446 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5447 if (!read)
5448 return failure();
5449 // TODO: support 0-d corner case.
5450 if (read.getTransferRank() == 0)
5451 return failure();
5452 // For now, only accept minor-identity maps. Future: also accept the case where the composition is a minor identity.
5453 if (!read.getPermutationMap().isMinorIdentity() ||
5454 !write.getPermutationMap().isMinorIdentity())
5455 return failure();
5456 // Bail on mismatching ranks.
5457 if (read.getTransferRank() != write.getTransferRank())
5458 return failure();
5459 // Bail on potential out-of-bounds accesses.
5460 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5461 return failure();
5462 // Tensor types must be the same.
5463 if (read.getBase().getType() != rankedTensorType)
5464 return failure();
5465 // Vector types must be the same.
5466 if (read.getVectorType() != write.getVectorType())
5467 return failure();
5468 // Vector and Tensor shapes must match.
5469 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5470 return failure();
5471 // If any index is nonzero.
5472 auto isNotConstantZero = [](Value v) {
5473 auto cstOp = getConstantIntValue(v);
5474 return !cstOp.has_value() || cstOp.value() != 0;
5475 };
5476 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5477 llvm::any_of(write.getIndices(), isNotConstantZero))
5478 return failure();
5479 // Success.
5480 results.push_back(read.getBase());
5481 return success();
5482}
5483
5484static bool checkSameValueWAR(vector::TransferReadOp read,
5485 vector::TransferWriteOp write) {
5486 return read.getBase() == write.getBase() &&
5487 read.getIndices() == write.getIndices() &&
5488 read.getPermutationMap() == write.getPermutationMap() &&
5489 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5490 !write.getMask();
5491}
5492 /// Fold write-after-read, i.e. a transfer_write that writes back the value just read:
5493/// ```
5494/// %t0 = ...
5495/// %v = vector.transfer_read %t0[%c0...] :
5496/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5497/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5498/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5499/// ```
5500///
5501/// into:
5502///
5503/// ```
5504/// %t0
5505/// ```
5506static LogicalResult foldWAR(TransferWriteOp write,
5507 SmallVectorImpl<OpFoldResult> &results) {
5508 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5509 return failure();
5510 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5511 if (!read)
5512 return failure();
5513
5514 if (!checkSameValueWAR(read, write))
5515 return failure();
5516 results.push_back(read.getBase());
5517 return success();
5518}
5519
5520LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5521 SmallVectorImpl<OpFoldResult> &results) {
5522 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5523 return success();
5524 if (succeeded(foldWAR(*this, results)))
5525 return success();
5526 if (succeeded(foldTransferInBoundsAttribute(*this)))
5527 return success();
5528 if (succeeded(foldTransferFullMask(*this)))
5529 return success();
5530 return memref::foldMemRefCast(*this);
5531}
5532
5533//===----------------------------------------------------------------------===//
5534// TransferWriteOp: other methods.
5535//===----------------------------------------------------------------------===//
5536std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5537 return llvm::to_vector<4>(getVectorType().getShape());
5538}
5539
5540void TransferWriteOp::getEffects(
5541 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5542 &effects) {
5543 if (llvm::isa<MemRefType>(getShapedType()))
5544 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5545 SideEffects::DefaultResource::get());
5546}
5547
5548Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5549 if (hasPureTensorSemantics())
5552}
5553
5554namespace {
5555 /// Remove a dead transfer write from the SSA chain so that it can be
5556 /// eliminated by DCE.
5557/// ```
5558/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5559/// : vector<1x4xf32>, tensor<4x4xf32>
5560/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
5561/// : vector<1x4xf32>, tensor<4x4xf32>
5562/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5563/// : vector<1x4xf32>, tensor<4x4xf32>
5564/// ```
5565///
5566/// into:
5567///
5568/// ```
5569/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5570/// : vector<1x4xf32>, tensor<4x4xf32>
5571/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
5572/// : vector<1x4xf32>, tensor<4x4xf32>
5573/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
5574/// : vector<1x4xf32>, tensor<4x4xf32>
5575/// ```
5576///
5577/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
5578/// any other uses.
5579class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
5580public:
5581 using Base::Base;
5582 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
5583 PatternRewriter &rewriter) const override {
5584 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
5585 return failure();
5586 vector::TransferWriteOp writeToModify = writeOp;
5587
5588 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5589 while (defWrite) {
5590 if (checkSameValueWAW(writeOp, defWrite)) {
5591 rewriter.modifyOpInPlace(writeToModify, [&]() {
5592 writeToModify.getBaseMutable().assign(defWrite.getBase());
5593 });
5594 return success();
5595 }
5596 if (!isDisjointTransferIndices(
5597 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5598 cast<VectorTransferOpInterface>(writeOp.getOperation())))
5599 break;
5600 // If the previous write op doesn't have any other use we can safely look
5601 // at the previous store to see if it can be removed.
5602 if (!defWrite->hasOneUse())
5603 break;
5604 writeToModify = defWrite;
5605 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5606 }
5607 return failure();
5608 }
5609};
5610
5611/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
5612/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
5613/// overwritten and inserted into another tensor. After this rewrite, the
5614/// operations bufferize in-place since all of them work on the same slice.
5615///
5616/// For example:
5617/// ```mlir
5618/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
5619/// : vector<8x16xf32>, tensor<8x16xf32>
5620/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
5621/// : tensor<8x16xf32> to tensor<?x?xf32>
5622/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5623/// : tensor<?x?xf32> into tensor<27x37xf32>
5624/// ```
5625/// folds to
5626/// ```mlir
5627/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5628/// : tensor<27x37xf32> to tensor<?x?xf32>
5629/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
5630/// : vector<8x16xf32>, tensor<?x?xf32>
5631/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
5632/// : tensor<?x?xf32> into tensor<27x37xf32>
5633/// ```
5634struct SwapExtractSliceOfTransferWrite
5635 : public OpRewritePattern<tensor::InsertSliceOp> {
5636public:
5637 using Base::Base;
5638
5639 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
5640 PatternRewriter &rewriter) const override {
5641 if (!insertOp.hasUnitStride())
5642 return failure();
5643 auto extractOp =
5644 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
5645 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
5646 return failure();
5647 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
5648 if (!transferOp || !transferOp->hasOneUse())
5649 return failure();
5650
5651 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
5652 // rank-reducing.
5653 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
5654 return rewriter.notifyMatchFailure(insertOp,
5655 "use-def chain is rank-reducing");
5656 }
5657
5658 // Fail if tensor::ExtractSliceOp has non-zero offset.
5659 if (!extractOp.hasZeroOffset()) {
5660 return rewriter.notifyMatchFailure(insertOp,
5661 "ExtractSliceOp has non-zero offset");
5662 }
5663
5664 // Fail if vector::TransferWriteOp has non-zero indices.
5665 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
5666 return getConstantIntValue(value) == static_cast<int64_t>(0);
5667 })) {
5668 return rewriter.notifyMatchFailure(insertOp,
5669 "TranferWriteOp has non-zero offset");
5670 }
5671
5672 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
5673 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
5674 return rewriter.notifyMatchFailure(
5675 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
5676 }
5677
5678 for (auto [insertSize, extractSize] :
5679 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
5680 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
5681 return rewriter.notifyMatchFailure(
5682 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
5683 }
5684 }
5685
5686 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
5687 assert(transferOp.getVectorType().hasStaticShape() &&
5688 "expected vector to have a static shape");
5689 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
5690 SmallVector<int64_t> resultShape = applyPermutationMap(
5691 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
5692 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
5693 return rewriter.notifyMatchFailure(
5694 insertOp, "TransferWriteOp may not write the full tensor.");
5695 }
5696
5697 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
5698 // Set all in_bounds to false and let the folder infer them.
5699 SmallVector<bool> newInBounds(vectorShape.size(), false);
5700 auto newExtractOp = tensor::ExtractSliceOp::create(
5701 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
5702 insertOp.getDest(), insertOp.getMixedOffsets(),
5703 insertOp.getMixedSizes(), insertOp.getMixedStrides());
5704 auto newTransferWriteOp = TransferWriteOp::create(
5705 rewriter, transferOp.getLoc(), transferOp.getVector(),
5706 newExtractOp.getResult(), transferOp.getIndices(),
5707 transferOp.getPermutationMapAttr(),
5708 rewriter.getBoolArrayAttr(newInBounds));
5709 rewriter.modifyOpInPlace(insertOp, [&]() {
5710 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
5711 });
5712 return success();
5713 }
5714};
5715
5716} // namespace
5717
5718void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
5719 MLIRContext *context) {
5720 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
5721}
5722
5723FailureOr<std::optional<SmallVector<Value>>>
5724TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
5725 if (!hasPureBufferSemantics())
5726 return failure();
5728 ValueRange());
5729}
5730
5731//===----------------------------------------------------------------------===//
5732// LoadOp
5733//===----------------------------------------------------------------------===//
5734
5735static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
5736 VectorType vecTy,
5737 MemRefType memRefTy) {
5738 // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
5739 // need any strides limitations.
5740 if (!vecTy.isScalable() &&
5741 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
5742 return success();
5743
5744 if (!memRefTy.isLastDimUnitStride())
5745 return op->emitOpError("most minor memref dim must have unit stride");
5746 return success();
5747}
5748
5749LogicalResult vector::LoadOp::verify() {
5750 VectorType resVecTy = getVectorType();
5751 MemRefType memRefTy = getMemRefType();
5752
5753 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
5754 return failure();
5755
5756 if (memRefTy.getRank() < resVecTy.getRank())
5757 return emitOpError(
5758 "destination memref has lower rank than the result vector");
5759
5760 // Checks for vector memrefs.
5761 Type memElemTy = memRefTy.getElementType();
5762 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5763 if (memVecTy != resVecTy)
5764 return emitOpError("base memref and result vector types should match");
5765 memElemTy = memVecTy.getElementType();
5766 }
5767
5768 if (resVecTy.getElementType() != memElemTy)
5769 return emitOpError("base and result element types should match");
5770 if (llvm::size(getIndices()) != memRefTy.getRank())
5771 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5772 return success();
5773}
5774
5775OpFoldResult LoadOp::fold(FoldAdaptor) {
5776 if (succeeded(memref::foldMemRefCast(*this)))
5777 return getResult();
5778 return OpFoldResult();
5779}
5780
5781std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
5782 return llvm::to_vector<4>(getVectorType().getShape());
5783}
5784
5785FailureOr<std::optional<SmallVector<Value>>>
5786LoadOp::bubbleDownCasts(OpBuilder &builder) {
5788 getResult());
5789}
5790
5791//===----------------------------------------------------------------------===//
5792// StoreOp
5793//===----------------------------------------------------------------------===//
5794
5795LogicalResult vector::StoreOp::verify() {
5796 VectorType valueVecTy = getVectorType();
5797 MemRefType memRefTy = getMemRefType();
5798
5799 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
5800 return failure();
5801
5802 if (memRefTy.getRank() < valueVecTy.getRank())
5803 return emitOpError("source memref has lower rank than the vector to store");
5804
5805 // Checks for vector memrefs.
5806 Type memElemTy = memRefTy.getElementType();
5807 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
5808 if (memVecTy != valueVecTy)
5809 return emitOpError(
5810 "base memref and valueToStore vector types should match");
5811 memElemTy = memVecTy.getElementType();
5812 }
5813
5814 if (valueVecTy.getElementType() != memElemTy)
5815 return emitOpError("base and valueToStore element type should match");
5816 if (llvm::size(getIndices()) != memRefTy.getRank())
5817 return emitOpError("requires ") << memRefTy.getRank() << " indices";
5818 return success();
5819}
5820
5821LogicalResult StoreOp::fold(FoldAdaptor adaptor,
5822 SmallVectorImpl<OpFoldResult> &results) {
5823 return memref::foldMemRefCast(*this);
5824}
5825
5826std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
5827 return llvm::to_vector<4>(getVectorType().getShape());
5828}
5829
5830FailureOr<std::optional<SmallVector<Value>>>
5831StoreOp::bubbleDownCasts(OpBuilder &builder) {
5833 ValueRange());
5834}
5835
5836//===----------------------------------------------------------------------===//
5837// MaskedLoadOp
5838//===----------------------------------------------------------------------===//
5839
5840LogicalResult MaskedLoadOp::verify() {
5841 VectorType maskVType = getMaskVectorType();
5842 VectorType passVType = getPassThruVectorType();
5843 VectorType resVType = getVectorType();
5844 MemRefType memType = getMemRefType();
5845
5846 if (resVType.getElementType() != memType.getElementType())
5847 return emitOpError("base and result element type should match");
5848 if (llvm::size(getIndices()) != memType.getRank())
5849 return emitOpError("requires ") << memType.getRank() << " indices";
5850 if (resVType.getShape() != maskVType.getShape())
5851 return emitOpError("expected result shape to match mask shape");
5852 if (resVType != passVType)
5853 return emitOpError("expected pass_thru of same type as result type");
5854 return success();
5855}
5856
5857namespace {
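/// Canonicalizes masked loads whose mask is statically known: an all-true
/// mask becomes an unmasked vector.load, an all-false mask becomes the
/// pass-thru value. A minimal sketch of the all-true case (the SSA names and
/// types below are illustrative only):
///
///   %m = vector.constant_mask [4] : vector<4xi1>
///   %r = vector.maskedload %base[%i], %m, %pass
///       : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
/// -->
///   %r = vector.load %base[%i] : memref<?xf32>, vector<4xf32>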
5858class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
5859public:
5860 using Base::Base;
5861 LogicalResult matchAndRewrite(MaskedLoadOp load,
5862 PatternRewriter &rewriter) const override {
 5863 switch (getMaskFormat(load.getMask())) {
 5864 case MaskFormat::AllTrue:
 5865 rewriter.replaceOpWithNewOp<vector::LoadOp>(
 5866 load, load.getType(), load.getBase(), load.getIndices());
 5867 return success();
 5868 case MaskFormat::AllFalse:
 5869 rewriter.replaceOp(load, load.getPassThru());
 5870 return success();
 5871 case MaskFormat::Unknown:
 5872 return failure();
5873 }
5874 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
5875 }
5876};
5877} // namespace
5878
5879void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5880 MLIRContext *context) {
5881 results.add<MaskedLoadFolder>(context);
5882}
5883
5884OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
5885 if (succeeded(memref::foldMemRefCast(*this)))
5886 return getResult();
5887 return OpFoldResult();
5888}
5889
5890FailureOr<std::optional<SmallVector<Value>>>
5891MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
5893 getResult());
5894}
5895
5896//===----------------------------------------------------------------------===//
5897// MaskedStoreOp
5898//===----------------------------------------------------------------------===//
5899
5900LogicalResult MaskedStoreOp::verify() {
5901 VectorType maskVType = getMaskVectorType();
5902 VectorType valueVType = getVectorType();
5903 MemRefType memType = getMemRefType();
5904
5905 if (valueVType.getElementType() != memType.getElementType())
5906 return emitOpError("base and valueToStore element type should match");
5907 if (llvm::size(getIndices()) != memType.getRank())
5908 return emitOpError("requires ") << memType.getRank() << " indices";
5909 if (valueVType.getShape() != maskVType.getShape())
5910 return emitOpError("expected valueToStore shape to match mask shape");
5911 return success();
5912}
5913
5914namespace {
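/// Canonicalizes masked stores whose mask is statically known: an all-true
/// mask becomes an unmasked vector.store, an all-false mask erases the store.
/// A minimal sketch of the all-true case (illustrative names and types):
///
///   %m = vector.constant_mask [4] : vector<4xi1>
///   vector.maskedstore %base[%i], %m, %value
///       : memref<?xf32>, vector<4xi1>, vector<4xf32>
/// -->
///   vector.store %value, %base[%i] : memref<?xf32>, vector<4xf32>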
5915class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
5916public:
5917 using Base::Base;
5918 LogicalResult matchAndRewrite(MaskedStoreOp store,
5919 PatternRewriter &rewriter) const override {
 5920 switch (getMaskFormat(store.getMask())) {
 5921 case MaskFormat::AllTrue:
 5922 rewriter.replaceOpWithNewOp<vector::StoreOp>(
 5923 store, store.getValueToStore(), store.getBase(), store.getIndices());
 5924 return success();
 5925 case MaskFormat::AllFalse:
 5926 rewriter.eraseOp(store);
 5927 return success();
 5928 case MaskFormat::Unknown:
 5929 return failure();
5930 }
5931 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
5932 }
5933};
5934} // namespace
5935
5936void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
5937 MLIRContext *context) {
5938 results.add<MaskedStoreFolder>(context);
5939}
5940
5941LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
5942 SmallVectorImpl<OpFoldResult> &results) {
5943 return memref::foldMemRefCast(*this);
5944}
5945
5946FailureOr<std::optional<SmallVector<Value>>>
5947MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
5949 ValueRange());
5950}
5951
5952//===----------------------------------------------------------------------===//
5953// GatherOp
5954//===----------------------------------------------------------------------===//
5955
5956LogicalResult GatherOp::verify() {
5957 VectorType indVType = getIndexVectorType();
5958 VectorType maskVType = getMaskVectorType();
5959 VectorType resVType = getVectorType();
5960 ShapedType baseType = getBaseType();
5961
5962 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
5963 return emitOpError("requires base to be a memref or ranked tensor type");
5964
5965 if (resVType.getElementType() != baseType.getElementType())
5966 return emitOpError("base and result element type should match");
5967 if (llvm::size(getOffsets()) != baseType.getRank())
5968 return emitOpError("requires ") << baseType.getRank() << " indices";
5969 if (resVType.getShape() != indVType.getShape())
5970 return emitOpError("expected result dim to match indices dim");
5971 if (resVType.getShape() != maskVType.getShape())
5972 return emitOpError("expected result dim to match mask dim");
5973 if (resVType != getPassThruVectorType())
5974 return emitOpError("expected pass_thru of same type as result type");
5975 return success();
5976}
5977
5978// MaskableOpInterface methods.
5979
5980/// Returns the mask type expected by this operation. Mostly used for
 5981/// verification purposes. It requires the operation to be vectorized.
5982Type GatherOp::getExpectedMaskType() {
5983 auto vecType = this->getIndexVectorType();
5984 return VectorType::get(vecType.getShape(),
5985 IntegerType::get(vecType.getContext(), /*width=*/1),
5986 vecType.getScalableDims());
5987}
5988
5989std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
5990 return llvm::to_vector<4>(getVectorType().getShape());
5991}
5992
 5993/// Check whether `indexVec` is a constant 1-D vector of consecutive values [0, 1, 2, ...].
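/// Both of the following (illustrative) producers match:
///   %idx = vector.step : vector<4xindex>
///   %idx = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>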
5994static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5995 auto vecType = dyn_cast<VectorType>(indexVec.getType());
5996 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5997 return failure();
5998
5999 if (indexVec.getDefiningOp<StepOp>())
6000 return success();
6001
6002 DenseIntElementsAttr elements;
6003 if (!matchPattern(indexVec, m_Constant(&elements)))
6004 return failure();
6005
6006 return success(
6007 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6008}
6009
6010namespace {
6011class GatherFolder final : public OpRewritePattern<GatherOp> {
6012public:
6013 using Base::Base;
6014 LogicalResult matchAndRewrite(GatherOp gather,
6015 PatternRewriter &rewriter) const override {
 6016 switch (getMaskFormat(gather.getMask())) {
 6017 case MaskFormat::AllTrue:
 6018 return failure(); // no unmasked equivalent
 6019 case MaskFormat::AllFalse:
 6020 rewriter.replaceOp(gather, gather.getPassThru());
 6021 return success();
 6022 case MaskFormat::Unknown:
 6023 return failure();
6024 }
6025 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6026 }
6027};
6028
6029/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6030/// maskedload. Only 1D fixed vectors are supported for now.
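///
/// A minimal sketch of the rewrite (illustrative names and types):
///   %g = vector.gather %base[%c0][%step], %mask, %pass
///       : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
///         into vector<4xf32>
/// -->
///   %g = vector.maskedload %base[%c0], %mask, %pass
///       : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>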
6031class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6032public:
6033 using Base::Base;
6034 LogicalResult matchAndRewrite(GatherOp op,
6035 PatternRewriter &rewriter) const override {
6036 if (!isa<MemRefType>(op.getBase().getType()))
6037 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6038
6039 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6040 return failure();
6041
6042 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6043 op.getOffsets(), op.getMask(),
6044 op.getPassThru());
6045 return success();
6046 }
6047};
6048} // namespace
6049
6050void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6051 MLIRContext *context) {
6052 results.add<GatherFolder, FoldContiguousGather>(context);
6053}
6054
6055FailureOr<std::optional<SmallVector<Value>>>
6056GatherOp::bubbleDownCasts(OpBuilder &builder) {
6058 getResult());
6059}
6060
6061//===----------------------------------------------------------------------===//
6062// ScatterOp
6063//===----------------------------------------------------------------------===//
6064
6065LogicalResult ScatterOp::verify() {
6066 VectorType indVType = getIndexVectorType();
6067 VectorType maskVType = getMaskVectorType();
6068 VectorType valueVType = getVectorType();
6069 ShapedType baseType = getBaseType();
6070
6071 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6072 return emitOpError("requires base to be a memref or ranked tensor type");
6073
6074 if (valueVType.getElementType() != baseType.getElementType())
6075 return emitOpError("base and valueToStore element type should match");
6076 if (llvm::size(getOffsets()) != baseType.getRank())
6077 return emitOpError("requires ") << baseType.getRank() << " indices";
6078 if (valueVType.getShape() != indVType.getShape())
6079 return emitOpError("expected valueToStore dim to match indices dim");
6080 if (valueVType.getShape() != maskVType.getShape())
6081 return emitOpError("expected valueToStore dim to match mask dim");
6082 return success();
6083}
6084namespace {
6085class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6086public:
6087 using Base::Base;
6088 LogicalResult matchAndRewrite(ScatterOp scatter,
6089 PatternRewriter &rewriter) const override {
 6090 switch (getMaskFormat(scatter.getMask())) {
 6091 case MaskFormat::AllTrue:
 6092 return failure(); // no unmasked equivalent
 6093 case MaskFormat::AllFalse:
 6094 rewriter.eraseOp(scatter);
 6095 return success();
 6096 case MaskFormat::Unknown:
 6097 return failure();
6098 }
6099 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6100 }
6101};
6102
6103/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6104/// maskedstore. Only 1D fixed vectors are supported for now.
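///
/// A minimal sketch of the rewrite (illustrative names and types):
///   vector.scatter %base[%c0][%step], %mask, %value
///       : memref<?xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
/// -->
///   vector.maskedstore %base[%c0], %mask, %value
///       : memref<?xf32>, vector<4xi1>, vector<4xf32>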
6105class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6106public:
6107 using Base::Base;
6108 LogicalResult matchAndRewrite(ScatterOp op,
6109 PatternRewriter &rewriter) const override {
6110 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6111 return failure();
6112
6113 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6114 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6115 return success();
6116 }
6117};
6118} // namespace
6119
6120void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6121 MLIRContext *context) {
6122 results.add<ScatterFolder, FoldContiguousScatter>(context);
6123}
6124
6125FailureOr<std::optional<SmallVector<Value>>>
6126ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6128 ValueRange());
6129}
6130
6131//===----------------------------------------------------------------------===//
6132// ExpandLoadOp
6133//===----------------------------------------------------------------------===//
6134
6135LogicalResult ExpandLoadOp::verify() {
6136 VectorType maskVType = getMaskVectorType();
6137 VectorType passVType = getPassThruVectorType();
6138 VectorType resVType = getVectorType();
6139 MemRefType memType = getMemRefType();
6140
6141 if (resVType.getElementType() != memType.getElementType())
6142 return emitOpError("base and result element type should match");
6143 if (llvm::size(getIndices()) != memType.getRank())
6144 return emitOpError("requires ") << memType.getRank() << " indices";
6145 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
6146 return emitOpError("expected result dim to match mask dim");
6147 if (resVType != passVType)
6148 return emitOpError("expected pass_thru of same type as result type");
6149 return success();
6150}
6151
6152namespace {
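/// Canonicalizes expanding loads whose mask is statically known: with an
/// all-true mask no expansion happens, so the op becomes a plain vector.load;
/// with an all-false mask it becomes the pass-thru value. A minimal sketch of
/// the all-true case (illustrative names and types):
///
///   %r = vector.expandload %base[%i], %m, %pass
///       : memref<?xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
/// -->
///   %r = vector.load %base[%i] : memref<?xf32>, vector<4xf32>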
6153class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6154public:
6155 using Base::Base;
6156 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6157 PatternRewriter &rewriter) const override {
 6158 switch (getMaskFormat(expand.getMask())) {
 6159 case MaskFormat::AllTrue:
 6160 rewriter.replaceOpWithNewOp<vector::LoadOp>(
 6161 expand, expand.getType(), expand.getBase(), expand.getIndices());
 6162 return success();
 6163 case MaskFormat::AllFalse:
 6164 rewriter.replaceOp(expand, expand.getPassThru());
 6165 return success();
 6166 case MaskFormat::Unknown:
 6167 return failure();
6168 }
6169 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6170 }
6171};
6172} // namespace
6173
6174void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6175 MLIRContext *context) {
6176 results.add<ExpandLoadFolder>(context);
6177}
6178
6179FailureOr<std::optional<SmallVector<Value>>>
6180ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6182 getResult());
6183}
6184
6185//===----------------------------------------------------------------------===//
6186// CompressStoreOp
6187//===----------------------------------------------------------------------===//
6188
6189LogicalResult CompressStoreOp::verify() {
6190 VectorType maskVType = getMaskVectorType();
6191 VectorType valueVType = getVectorType();
6192 MemRefType memType = getMemRefType();
6193
6194 if (valueVType.getElementType() != memType.getElementType())
6195 return emitOpError("base and valueToStore element type should match");
6196 if (llvm::size(getIndices()) != memType.getRank())
6197 return emitOpError("requires ") << memType.getRank() << " indices";
6198 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
6199 return emitOpError("expected valueToStore dim to match mask dim");
6200 return success();
6201}
6202
6203namespace {
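/// Canonicalizes compressing stores whose mask is statically known: with an
/// all-true mask no compression happens, so the op becomes a plain
/// vector.store; with an all-false mask the store is erased. A minimal sketch
/// of the all-true case (illustrative names and types):
///
///   vector.compressstore %base[%i], %m, %value
///       : memref<?xf32>, vector<4xi1>, vector<4xf32>
/// -->
///   vector.store %value, %base[%i] : memref<?xf32>, vector<4xf32>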
6204class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6205public:
6206 using Base::Base;
6207 LogicalResult matchAndRewrite(CompressStoreOp compress,
6208 PatternRewriter &rewriter) const override {
 6209 switch (getMaskFormat(compress.getMask())) {
 6210 case MaskFormat::AllTrue:
 6211 rewriter.replaceOpWithNewOp<vector::StoreOp>(
 6212 compress, compress.getValueToStore(), compress.getBase(),
 6213 compress.getIndices());
 6214 return success();
 6215 case MaskFormat::AllFalse:
 6216 rewriter.eraseOp(compress);
 6217 return success();
 6218 case MaskFormat::Unknown:
 6219 return failure();
6220 }
6221 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6222 }
6223};
6224} // namespace
6225
6226void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6227 MLIRContext *context) {
6228 results.add<CompressStoreFolder>(context);
6229}
6230
6231FailureOr<std::optional<SmallVector<Value>>>
6232CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6234 ValueRange());
6235}
6236
6237//===----------------------------------------------------------------------===//
6238// ShapeCastOp
6239//===----------------------------------------------------------------------===//
6240
6241void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6242 SetIntRangeFn setResultRanges) {
6243 setResultRanges(getResult(), argRanges.front());
6244}
6245
6246std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6247 return llvm::to_vector<4>(getResultVectorType().getShape());
6248}
6249
6250LogicalResult ShapeCastOp::verify() {
6251
6252 VectorType sourceType = getSourceVectorType();
6253 VectorType resultType = getResultVectorType();
6254
6255 // Check that element type is preserved
6256 if (sourceType.getElementType() != resultType.getElementType())
6257 return emitOpError("has different source and result element types");
6258
6259 // Check that number of elements is preserved
6260 int64_t sourceNElms = sourceType.getNumElements();
6261 int64_t resultNElms = resultType.getNumElements();
6262 if (sourceNElms != resultNElms) {
6263 return emitOpError() << "has different number of elements at source ("
6264 << sourceNElms << ") and result (" << resultNElms
6265 << ")";
6266 }
6267
6268 // Check that (non-)scalability is preserved
6269 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6270 int64_t resultNScalableDims = resultType.getNumScalableDims();
6271 if (sourceNScalableDims != resultNScalableDims)
6272 return emitOpError() << "has different number of scalable dims at source ("
6273 << sourceNScalableDims << ") and result ("
6274 << resultNScalableDims << ")";
6275
6276 return success();
6277}
6278
6279/// Return true if `transpose` does not permute a pair of non-unit dims.
6280/// By `order preserving` we mean that the flattened versions of the input and
6281/// output vectors are (numerically) identical. In other words `transpose` is
6282/// effectively a shape cast.
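///
/// For instance (illustrative), permutation [1, 0] on vector<1x4xf32> is
/// order preserving because only a unit dim moves, whereas [1, 0] on
/// vector<2x3xf32> is not.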
6283static bool isOrderPreserving(TransposeOp transpose) {
6284 ArrayRef<int64_t> permutation = transpose.getPermutation();
6285 VectorType sourceType = transpose.getSourceVectorType();
6286 ArrayRef<int64_t> inShape = sourceType.getShape();
6287 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6288 auto isNonScalableUnitDim = [&](int64_t dim) {
6289 return inShape[dim] == 1 && !inDimIsScalable[dim];
6290 };
6291 int64_t current = 0;
6292 for (auto p : permutation) {
6293 if (!isNonScalableUnitDim(p)) {
6294 if (p < current) {
6295 return false;
6296 }
6297 current = p;
6298 }
6299 }
6300 return true;
6301}
6302
6303OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6304
6305 VectorType resultType = getType();
6306
6307 // No-op shape cast.
6308 if (getSource().getType() == resultType)
6309 return getSource();
6310
6311 // shape_cast(shape_cast(x)) -> shape_cast(x)
6312 if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6313 setOperand(precedingShapeCast.getSource());
6314 return getResult();
6315 }
6316
6317 // shape_cast(transpose(x)) -> shape_cast(x)
6318 if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6319 if (isOrderPreserving(transpose)) {
6320 setOperand(transpose.getVector());
6321 return getResult();
6322 }
6323 return {};
6324 }
6325
6326 // Y = shape_cast(broadcast(X))
6327 // -> X, if X and Y have same type
6328 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6329 if (bcastOp.getSourceType() == resultType)
6330 return bcastOp.getSource();
6331 }
6332
6333 // shape_cast(constant) -> constant
6334 if (auto denseAttr =
6335 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6336 return denseAttr.reshape(getType());
6337
6338 // shape_cast(poison) -> poison
6339 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
6340 return ub::PoisonAttr::get(getContext());
6341
6342 return {};
6343}
6344
6345namespace {
6346
6347/// Helper function that computes a new vector type based on the input vector
6348/// type by removing the trailing one dims:
6349///
 6350/// vector<4x1x1xi1> --> vector<4xi1>
6351///
6352static VectorType trimTrailingOneDims(VectorType oldType) {
6353 ArrayRef<int64_t> oldShape = oldType.getShape();
6354 ArrayRef<int64_t> newShape = oldShape;
6355
6356 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6357 ArrayRef<bool> newScalableDims = oldScalableDims;
6358
6359 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6360 newShape = newShape.drop_back(1);
6361 newScalableDims = newScalableDims.drop_back(1);
6362 }
6363
6364 // Make sure we have at least 1 dimension.
6365 // TODO: Add support for 0-D vectors.
6366 if (newShape.empty()) {
6367 newShape = oldShape.take_back();
6368 newScalableDims = oldScalableDims.take_back();
6369 }
6370
6371 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6372}
6373
6374/// Folds qualifying shape_cast(create_mask) into a new create_mask
6375///
6376/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6377/// dimension. If the input vector comes from `vector.create_mask` for which
6378/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6379/// to fold shape_cast into create_mask.
6380///
6381/// BEFORE:
6382/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6383/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6384/// AFTER:
6385/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6386class ShapeCastCreateMaskFolderTrailingOneDim final
6387 : public OpRewritePattern<ShapeCastOp> {
6388public:
6389 using Base::Base;
6390
6391 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6392 PatternRewriter &rewriter) const override {
6393 Value shapeOpSrc = shapeOp->getOperand(0);
6394 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6395 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6396 if (!createMaskOp && !constantMaskOp)
6397 return failure();
6398
6399 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6400 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6401
6402 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6403 if (newVecType != shapeOpResTy)
6404 return failure();
6405
6406 auto numDimsToDrop =
6407 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6408
6409 // No unit dims to drop
6410 if (!numDimsToDrop)
6411 return failure();
6412
6413 if (createMaskOp) {
6414 auto maskOperands = createMaskOp.getOperands();
6415 auto numMaskOperands = maskOperands.size();
6416
6417 // Check every mask dim size to see whether it can be dropped
6418 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6419 --i) {
6420 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6421 if (!constant || (constant.value() != 1))
6422 return failure();
6423 }
6424 SmallVector<Value> newMaskOperands =
6425 maskOperands.drop_back(numDimsToDrop);
6426
6427 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6428 newMaskOperands);
6429 return success();
6430 }
6431
6432 if (constantMaskOp) {
6433 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6434 auto numMaskOperands = maskDimSizes.size();
6435
6436 // Check every mask dim size to see whether it can be dropped
6437 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6438 --i) {
6439 if (maskDimSizes[i] != 1)
6440 return failure();
6441 }
6442
6443 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6444 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6445 newMaskOperands);
6446 return success();
6447 }
6448
6449 return failure();
6450 }
6451};
6452
6453/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
6454/// i) Y = ShapeCast(X), or
6455/// ii) Y = Broadcast(X)
6456/// If both (i) and (ii) are possible, (i) is chosen.
6457class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6458public:
6459 using Base::Base;
6460
6461 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6462 PatternRewriter &rewriter) const override {
6463 auto broadcastOp =
6464 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6465 if (!broadcastOp)
6466 return failure();
6467
6468 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6469 bool srcIsScalar = !srcVectorType;
6470
6471 // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
6472 // Example:
6473 // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
6474 // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
6475 // to
6476 // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
6477 if (srcVectorType) {
6478 if (srcVectorType.getNumElements() ==
6479 shapeCastOp.getResultVectorType().getNumElements()) {
6480 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
6481 shapeCastOp, shapeCastOp.getResultVectorType(),
6482 broadcastOp.getSource());
6483 return success();
6484 }
6485 }
6486
6487 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6488 // Example
6489 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6490 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6491 // to
6492 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6493 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6494 if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
6495 BroadcastableToResult::Success) {
6496 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6497 shapeCastOp, dstVectorType, broadcastOp.getSource());
6498 return success();
6499 }
6500 return failure();
6501 }
6502};
6503
6504} // namespace
6505
6506void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
6507 MLIRContext *context) {
6508 results
6509 .add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder>(
6510 context);
6511}
6512
6513//===----------------------------------------------------------------------===//
6514// VectorBitCastOp
6515//===----------------------------------------------------------------------===//
6516
6517LogicalResult BitCastOp::verify() {
6518 auto sourceVectorType = getSourceVectorType();
6519 auto resultVectorType = getResultVectorType();
6520
6521 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
6522 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
6523 return emitOpError("dimension size mismatch at: ") << i;
6524 }
6525
6526 DataLayout dataLayout = DataLayout::closest(*this);
6527 auto sourceElementBits =
6528 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
6529 auto resultElementBits =
6530 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
6531
6532 if (sourceVectorType.getRank() == 0) {
6533 if (sourceElementBits != resultElementBits)
6534 return emitOpError("source/result bitwidth of the 0-D vector element "
6535 "types must be equal");
6536 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
6537 resultElementBits * resultVectorType.getShape().back()) {
6538 return emitOpError(
6539 "source/result bitwidth of the minor 1-D vectors must be equal");
6540 }
6541
6542 return success();
6543}
6544
6545OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
6546 // Nop cast.
6547 if (getSource().getType() == getResult().getType())
6548 return getSource();
6549
6550 // Canceling bitcasts.
6551 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
6552 if (getResult().getType() == otherOp.getSource().getType())
6553 return otherOp.getSource();
6554
6555 setOperand(otherOp.getSource());
6556 return getResult();
6557 }
6558
6559 Attribute sourceConstant = adaptor.getSource();
6560 if (!sourceConstant)
6561 return {};
6562
6563 Type srcElemType = getSourceVectorType().getElementType();
6564 Type dstElemType = getResultVectorType().getElementType();
6565
6566 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
6567 if (floatPack.isSplat()) {
6568 auto splat = floatPack.getSplatValue<FloatAttr>();
6569
6570 // Casting fp16 into fp32.
6571 if (srcElemType.isF16() && dstElemType.isF32()) {
6572 uint32_t bits = static_cast<uint32_t>(
6573 splat.getValue().bitcastToAPInt().getZExtValue());
6574 // Duplicate the 16-bit pattern.
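        // For instance (illustrative), an f16 splat of 1.0 (bits 0x3C00)
        // produces the 32-bit pattern 0x3C003C00 for every result element.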
6575 bits = (bits << 16) | (bits & 0xffff);
6576 APInt intBits(32, bits);
6577 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
6578 return DenseElementsAttr::get(getResultVectorType(), floatBits);
6579 }
6580 }
6581 }
6582
6583 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
6584 if (intPack.isSplat()) {
6585 auto splat = intPack.getSplatValue<IntegerAttr>();
6586
6587 if (llvm::isa<IntegerType>(dstElemType)) {
6588 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
6589 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
6590
6591 // Casting to a larger integer bit width.
6592 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
6593 APInt intBits = splat.getValue().zext(dstBitWidth);
6594
6595 // Duplicate the lower width element.
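          // For instance (illustrative), an i8 splat of 0x01 bitcast to i32
          // produces the splat value 0x01010101.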
6596 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
6597 intBits = (intBits << srcBitWidth) | intBits;
6598 return DenseElementsAttr::get(getResultVectorType(), intBits);
6599 }
6600 }
6601 }
6602 }
6603
6604 return {};
6605}
6606
6607//===----------------------------------------------------------------------===//
6608// TypeCastOp
6609//===----------------------------------------------------------------------===//
6610
6611static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
6612 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
6613 SmallVector<int64_t, 8> res(memRefType.getShape());
6614 if (vectorType)
6615 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
6616 return res;
6617}
6618
6619/// Build the canonical memRefType with a single vector.
6620/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
6621void TypeCastOp::build(OpBuilder &builder, OperationState &result,
6622 Value source) {
6623 result.addOperands(source);
6624 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
6625 VectorType vectorType =
 6626 VectorType::get(extractShape(memRefType),
 6627 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
6628 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
6629 memRefType.getMemorySpace()));
6630}
6631
6632LogicalResult TypeCastOp::verify() {
6633 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
6634 if (!canonicalType.getLayout().isIdentity())
6635 return emitOpError("expects operand to be a memref with identity layout");
6636 if (!getResultMemRefType().getLayout().isIdentity())
6637 return emitOpError("expects result to be a memref with identity layout");
6638 if (getResultMemRefType().getMemorySpace() !=
6639 getMemRefType().getMemorySpace())
6640 return emitOpError("expects result in same memory space");
6641
6642 auto sourceType = getMemRefType();
6643 auto resultType = getResultMemRefType();
 6644 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
 6645 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
6646 return emitOpError(
6647 "expects result and operand with same underlying scalar type: ")
6648 << resultType;
6649 if (extractShape(sourceType) != extractShape(resultType))
6650 return emitOpError(
6651 "expects concatenated result and operand shapes to be equal: ")
6652 << resultType;
6653 return success();
6654}
6655
6656//===----------------------------------------------------------------------===//
6657// TransposeOp
6658//===----------------------------------------------------------------------===//
6659
6660void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
6661 Value vector, ArrayRef<int64_t> permutation) {
6662 VectorType vt = llvm::cast<VectorType>(vector.getType());
6663 SmallVector<int64_t, 4> transposedShape(vt.getRank());
6664 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
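  // Compute the permuted shape and scalability. For instance (illustrative),
  // permutation [1, 0] applied to vector<2x3xf32> yields vector<3x2xf32>.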
6665 for (unsigned i = 0; i < permutation.size(); ++i) {
6666 transposedShape[i] = vt.getShape()[permutation[i]];
6667 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
6668 }
6669
6670 result.addOperands(vector);
6671 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
6672 transposedScalableDims));
6673 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
6674 builder.getDenseI64ArrayAttr(permutation));
6675}
6676
6677OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
6678 // Eliminate splat constant transpose ops.
6679 if (auto splat =
6680 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
6681 return splat.reshape(getResultVectorType());
6682
6683 // Eliminate poison transpose ops.
6684 if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
6685 return ub::PoisonAttr::get(getContext());
6686
 6687 // Eliminate identity transposes and, more generally, any transposes that
 6688 // preserve the shape without permuting elements.
6689 //
6690 // Examples of what to fold:
6691 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6692 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6693 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6694 //
6695 // Example of what NOT to fold:
6696 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6697 //
6698 if (getSourceVectorType() == getResultVectorType() &&
6699 isOrderPreserving(*this))
6700 return getVector();
6701
6702 return {};
6703}
6704
6705LogicalResult vector::TransposeOp::verify() {
6706 VectorType vectorType = getSourceVectorType();
6707 VectorType resultType = getResultVectorType();
6708 int64_t rank = resultType.getRank();
6709 if (vectorType.getRank() != rank)
6710 return emitOpError("vector result rank mismatch: ") << rank;
6711 // Verify transposition array.
6712 ArrayRef<int64_t> perm = getPermutation();
6713 int64_t size = perm.size();
6714 if (rank != size)
6715 return emitOpError("transposition length mismatch: ") << size;
6716 SmallVector<bool, 8> seen(rank, false);
6717 for (const auto &ta : llvm::enumerate(perm)) {
6718 if (ta.value() < 0 || ta.value() >= rank)
6719 return emitOpError("transposition index out of range: ") << ta.value();
6720 if (seen[ta.value()])
6721 return emitOpError("duplicate position index: ") << ta.value();
6722 seen[ta.value()] = true;
6723 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
6724 return emitOpError("dimension size mismatch at: ") << ta.value();
6725 }
6726 return success();
6727}
6728
6729std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
6730 return llvm::to_vector<4>(getResultVectorType().getShape());
6731}
6732
6733void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6734 SetIntRangeFn setResultRanges) {
6735 setResultRanges(getResult(), argRanges.front());
6736}
6737
6738namespace {
6739
6740// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
6741class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6742public:
6743 using Base::Base;
6744
6745 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6746 PatternRewriter &rewriter) const override {
6747 // Composes two permutations: result[i] = permutation1[permutation2[i]].
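    // For instance (illustrative), permutation1 = [1, 2, 0] composed with
    // permutation2 = [2, 0, 1] gives [0, 1, 2], i.e. the identity.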
6748 auto composePermutations = [](ArrayRef<int64_t> permutation1,
6749 ArrayRef<int64_t> permutation2) {
6750 SmallVector<int64_t, 4> result;
6751 for (auto index : permutation2)
6752 result.push_back(permutation1[index]);
6753 return result;
6754 };
6755
6756 // Return if the input of 'transposeOp' is not defined by another transpose.
6757 vector::TransposeOp parentTransposeOp =
6758 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
6759 if (!parentTransposeOp)
6760 return failure();
6761
6762 SmallVector<int64_t, 4> permutation = composePermutations(
6763 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
6764 // Replace 'transposeOp' with a new transpose operation.
6765 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
6766 transposeOp, transposeOp.getResult().getType(),
6767 parentTransposeOp.getVector(), permutation);
6768 return success();
6769 }
6770};
6771
6772/// Replace transpose(splat-like(v)) with broadcast(v)
6773class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6774public:
6775 using Base::Base;
6776
6777 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6778 PatternRewriter &rewriter) const override {
6779 Value splat = getScalarSplatSource(transposeOp.getVector());
6780 if (!splat)
6781 return failure();
6782
6783 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6784 transposeOp, transposeOp.getResultVectorType(), splat);
6785 return success();
6786 }
6787};
6788
6789/// Folds transpose(create_mask) into a new transposed create_mask.
6790class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
6791public:
6792 using Base::Base;
6793
6794 LogicalResult matchAndRewrite(TransposeOp transpOp,
6795 PatternRewriter &rewriter) const override {
6796 Value transposeSrc = transpOp.getVector();
6797 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
6798 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
6799 if (!createMaskOp && !constantMaskOp)
6800 return failure();
6801
6802 // Get the transpose permutation and apply it to the vector.create_mask or
6803 // vector.constant_mask operands.
6804 ArrayRef<int64_t> permutation = transpOp.getPermutation();
6805
6806 if (createMaskOp) {
6807 auto maskOperands = createMaskOp.getOperands();
6808 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
6809 applyPermutationToVector(newOperands, permutation);
6810
6811 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
6812 transpOp, transpOp.getResultVectorType(), newOperands);
6813 return success();
6814 }
6815
6816 // ConstantMaskOp case.
6817 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6818 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
6819
6820 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
6821 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
6822 return success();
6823 }
6824};
6825
6826/// Folds transpose(shape_cast) into a new shape_cast.
6827class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
6828public:
6829 using Base::Base;
6830
6831 LogicalResult matchAndRewrite(TransposeOp transposeOp,
6832 PatternRewriter &rewriter) const override {
6833 auto shapeCastOp =
6834 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
6835 if (!shapeCastOp)
6836 return failure();
6837 if (!isOrderPreserving(transposeOp))
6838 return failure();
6839
6840 VectorType resultType = transposeOp.getType();
6841
6842 // We don't need to check isValidShapeCast at this point, because it is
 6843 // guaranteed that merging the transpose into the shape_cast is a valid
6844 // shape_cast, because the transpose just inserts/removes ones.
6845
6846 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
6847 shapeCastOp.getSource());
6848 return success();
6849 }
6850};
6851
6852/// Folds transpose(from_elements(...)) into a new from_elements with permuted
6853/// operands matching the transposed shape.
6854///
6855/// Example:
6856///
 6857///   %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12
 6858///       : vector<2x3xi32>
 6859///   %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
6860///
6861/// becomes ->
6862///
6863/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
6864/// vector<3x2xi32>
6865///
6866class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
6867public:
6868 using Base::Base;
6869 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6870 PatternRewriter &rewriter) const override {
6871 auto fromElementsOp =
6872 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
6873 if (!fromElementsOp)
6874 return failure();
6875
6876 VectorType srcTy = fromElementsOp.getDest().getType();
6877 VectorType dstTy = transposeOp.getType();
6878
6879 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
6880 int64_t rank = srcTy.getRank();
6881
6882 // Build inverse permutation to map destination indices back to source.
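    // For instance (illustrative), permutation [2, 0, 1] has inverse [1, 2, 0].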
6883 SmallVector<int64_t> inversePerm(rank, 0);
6884 for (int64_t i = 0; i < rank; ++i)
6885 inversePerm[permutation[i]] = i;
6886
6887 ArrayRef<int64_t> srcShape = srcTy.getShape();
6888 ArrayRef<int64_t> dstShape = dstTy.getShape();
6889 SmallVector<int64_t> srcIdx(rank, 0);
6890 SmallVector<int64_t> dstIdx(rank, 0);
6891 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
6892 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
6893
6894 auto elementsOld = fromElementsOp.getElements();
6895 SmallVector<Value> elementsNew;
6896 int64_t dstNumElements = dstTy.getNumElements();
6897 elementsNew.reserve(dstNumElements);
6898
6899 // For each element in destination row-major order, pick the corresponding
6900 // source element.
6901 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
6902 // Pick the destination element index.
6903 dstIdx = delinearize(linearIdx, dstStrides);
6904 // Map the destination element index to the source element index.
6905 for (int64_t j = 0; j < rank; ++j)
6906 srcIdx[j] = dstIdx[inversePerm[j]];
6907 // Linearize the source element index.
6908 int64_t srcLin = linearize(srcIdx, srcStrides);
6909 // Add the source element to the new elements.
6910 elementsNew.push_back(elementsOld[srcLin]);
6911 }
6912
6913 rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
6914 elementsNew);
6915 return success();
6916 }
6917};
6918
6919/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
6920/// 'order preserving', where 'order preserving' means the flattened
6921/// inputs and outputs of the transpose have identical (numerical) values.
6922///
6923/// Example:
6924/// ```
6925/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
6926/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
6927/// to vector<8x1xi32>
6928/// ```
6929/// can be rewritten as the equivalent
6930/// ```
6931/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
6932/// ```
6933/// The algorithm works by partitioning dimensions into groups that can be
6934/// locally permuted while preserving order, and checks that the transpose
6935/// only permutes within these groups.
6936///
6937/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
6938/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
6939/// broadcasting from 1x1x4x1x1x7.
 6940/// broadcasting from 1x1x4x1x1x7, whose dims partition into four groups:
 6941/// dims [0,1] (1x1), dim [2] (4), dims [3,4] (1x1), and dim [5] (7).
6943/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
6944class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
6945public:
6946 using Base::Base;
6947 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
6948 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
6949
6950 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
6951 PatternRewriter &rewriter) const override {
6952
6953 vector::BroadcastOp broadcast =
6954 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
6955 if (!broadcast) {
6956 return rewriter.notifyMatchFailure(transpose,
6957 "not preceded by a broadcast");
6958 }
6959
6960 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
6961 VectorType outputType = transpose.getResultVectorType();
6962
6963 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
6964 bool inputIsScalar = !inputType;
6965 if (inputIsScalar) {
6966 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
6967 broadcast.getSource());
6968 return success();
6969 }
6970
6971 ArrayRef<int64_t> permutation = transpose.getPermutation();
6972 ArrayRef<int64_t> inputShape = inputType.getShape();
6973 int64_t inputRank = inputType.getRank();
6974 int64_t outputRank = transpose.getType().getRank();
6975 int64_t deltaRank = outputRank - inputRank;
6976
6977 int low = 0;
6978 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
6979 bool notOne = inputShape[inputIndex] != 1;
6980 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
6981 bool groupEndFound = notOne || prevNotOne;
6982 if (groupEndFound) {
6983 int high = inputIndex + deltaRank;
6984 // Return failure if not all permutation destinations for indices in
6985 // [low, high) are in [low, high), i.e. the permutation is not local to
6986 // the group.
6987 for (int i = low; i < high; ++i) {
6988 if (permutation[i] < low || permutation[i] >= high) {
6989 return rewriter.notifyMatchFailure(
6990 transpose, "permutation not local to group");
6991 }
6992 }
6993 low = high;
6994 }
6995 }
6996
6997 // We don't need to check the final group [low, outputRank) because if it is
6998 // not locally bound, there must be a preceding group that already failed
6999 // the check (impossible to have just 1 non-locally bound group).
7000
7001 // The preceding logic also ensures that at this point, the output of the
7002 // transpose is definitely broadcastable from the input shape, assert so:
7003 assert(vector::isBroadcastableTo(inputType, outputType) ==
7004 vector::BroadcastableToResult::Success &&
7005 "not broadcastable directly to transpose output");
7006
7007 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7008 broadcast.getSource());
7009
7010 return success();
7011 }
7012};
7013
7014} // namespace
7015
7016void vector::TransposeOp::getCanonicalizationPatterns(
7017 RewritePatternSet &results, MLIRContext *context) {
7018 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7019 FoldTransposeSplat, FoldTransposeFromElements,
7020 FoldTransposeBroadcast>(context);
7021}
7022
7023//===----------------------------------------------------------------------===//
7024// ConstantMaskOp
7025//===----------------------------------------------------------------------===//
7026
7027void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7028 VectorType type, ConstantMaskKind kind) {
7029 assert(kind == ConstantMaskKind::AllTrue ||
7030 kind == ConstantMaskKind::AllFalse);
7031 build(builder, result, type,
7032 kind == ConstantMaskKind::AllTrue
7033 ? type.getShape()
7034 : SmallVector<int64_t>(type.getRank(), 0));
7035}
7036
7037LogicalResult ConstantMaskOp::verify() {
7038 auto resultType = llvm::cast<VectorType>(getResult().getType());
7039 // Check the corner case of 0-D vectors first.
7040 if (resultType.getRank() == 0) {
7041 if (getMaskDimSizes().size() != 1)
7042 return emitError("array attr must have length 1 for 0-D vectors");
7043 auto dim = getMaskDimSizes()[0];
7044 if (dim != 0 && dim != 1)
7045 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7046 return success();
7047 }
7048
7049 // Verify that array attr size matches the rank of the vector result.
7050 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7051 return emitOpError(
7052 "must specify array attr of size equal vector result rank");
7053 // Verify that each array attr element is in bounds of corresponding vector
7054 // result dimension size.
7055 auto resultShape = resultType.getShape();
7056 auto resultScalableDims = resultType.getScalableDims();
7057 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7058 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7059 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7060 return emitOpError(
7061 "array attr of size out of bounds of vector result dimension size");
7062 if (resultScalableDims[index] && maskDimSize != 0 &&
7063 maskDimSize != resultShape[index])
7064 return emitOpError(
7065 "only supports 'none set' or 'all set' scalable dimensions");
7066 }
7067 // Verify that if one mask dim size is zero, they all should be zero (because
7068 // the mask region is a conjunction of each mask dimension interval).
7069 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7070 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7071 if (anyZeros && !allZeros)
7072 return emitOpError("expected all mask dim sizes to be zeros, "
7073 "as a result of conjunction with zero mask dim");
7074 return success();
7075}
7076
7077bool ConstantMaskOp::isAllOnesMask() {
7078 auto resultType = getVectorType();
7079 // Check the corner case of 0-D vectors first.
7080 if (resultType.getRank() == 0) {
7081 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7082 return getMaskDimSizes()[0] == 1;
7083 }
7084 for (const auto [resultSize, maskDimSize] :
7085 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7086 if (maskDimSize < resultSize)
7087 return false;
7088 }
7089 return true;
7090}
7091
7092OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7093 ArrayRef<int64_t> bounds = getMaskDimSizes();
7094 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7095
7096 auto createBoolSplat = [&](bool x) {
 7097 return SplatElementsAttr::get(getVectorType(),
 7098 BoolAttr::get(getContext(), x));
7099 };
7100
7101 // Check the corner case of 0-D vectors first.
7102 if (vectorSizes.empty()) {
7103 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7104 return createBoolSplat(bounds[0] == 1);
7105 }
7106 // Fold vector.constant_mask to splat if possible.
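  // For instance (illustrative), vector.constant_mask [2, 3] : vector<2x3xi1>
  // folds to a dense<true> splat, and [0, 0] folds to dense<false>.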
7107 if (bounds == vectorSizes)
7108 return createBoolSplat(true);
7109 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7110 return createBoolSplat(false);
7111 return OpFoldResult();
7112}
7113
7114//===----------------------------------------------------------------------===//
7115// CreateMaskOp
7116//===----------------------------------------------------------------------===//
7117
7118void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7119 VectorType type,
7120 ArrayRef<OpFoldResult> mixedOperands) {
7121 SmallVector<Value> operands =
7122 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7123 build(builder, result, type, operands);
7124}
7125
7126LogicalResult CreateMaskOp::verify() {
7127 auto vectorType = llvm::cast<VectorType>(getResult().getType());
 7128 // Verify that an operand was specified for each result vector dimension.
7129 if (vectorType.getRank() == 0) {
7130 if (getNumOperands() != 1)
7131 return emitOpError(
7132 "must specify exactly one operand for 0-D create_mask");
7133 } else if (getNumOperands() !=
7134 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7135 return emitOpError(
7136 "must specify an operand for each result vector dimension");
7137 }
7138 return success();
7139}
7140
7141namespace {
7142
7143/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7144///
7145/// Ex 1:
7146/// %c2 = arith.constant 2 : index
7147/// %c3 = arith.constant 3 : index
7148/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7149/// Becomes:
7150/// vector.constant_mask [3, 2] : vector<4x3xi1>
7151///
7152/// Ex 2:
7153/// %c_neg_1 = arith.constant -1 : index
7154/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7155/// becomes:
7156/// vector.constant_mask [0] : vector<[8]xi1>
7157///
7158/// Ex 3:
7159/// %c8 = arith.constant 8 : index
7160/// %c16 = arith.constant 16 : index
7161/// %0 = vector.vscale
7162/// %1 = arith.muli %0, %c16 : index
7163/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7164/// becomes:
7165/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7166class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7167public:
7168 using Base::Base;
7169
7170 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7171 PatternRewriter &rewriter) const override {
7172 VectorType maskType = createMaskOp.getVectorType();
7173 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7174 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7175
7176 // Special case: Rank zero shape.
7177 constexpr std::array<int64_t, 1> rankZeroShape{1};
7178 constexpr std::array<bool, 1> rankZeroScalableDims{false};
7179 if (maskType.getRank() == 0) {
7180 maskTypeDimSizes = rankZeroShape;
7181 maskTypeDimScalableFlags = rankZeroScalableDims;
7182 }
7183
7184 // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7185 // collect the `constantDims` (for the ConstantMaskOp).
7186 SmallVector<int64_t, 4> constantDims;
7187 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7188 if (auto intSize = getConstantIntValue(dimSize)) {
7189 // Constant value.
7190 // If the mask dim is non-scalable this can be any value.
7191 // If the mask dim is scalable only zero (all-false) is supported.
7192 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7193 return failure();
7194 constantDims.push_back(*intSize);
7195 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7196 // Constant vscale multiple (e.g. 4 x vscale).
7197 // Must be all-true to fold to a ConstantMask.
7198 if (vscaleMultiplier < maskTypeDimSizes[i])
7199 return failure();
7200 constantDims.push_back(*vscaleMultiplier);
7201 } else {
7202 return failure();
7203 }
7204 }
7205
7206 // Clamp values to constant_mask bounds.
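    // (For instance, an illustrative constant of -1 clamps to 0, and a value
    // larger than the dimension clamps to the dimension size.)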
7207 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7208 value = std::clamp<int64_t>(value, 0, maskDimSize);
7209
7210 // If one of dim sizes is zero, set all dims to zero.
7211 if (llvm::is_contained(constantDims, 0))
7212 constantDims.assign(constantDims.size(), 0);
7213
7214 // Replace 'createMaskOp' with ConstantMaskOp.
7215 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7216 constantDims);
7217 return success();
7218 }
7219};
7220
7221} // namespace
7222
7223void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7224 MLIRContext *context) {
7225 results.add<CreateMaskFolder>(context);
7226}
7227
7228//===----------------------------------------------------------------------===//
7229// MaskOp
7230//===----------------------------------------------------------------------===//
7231
7232void MaskOp::build(
7233 OpBuilder &builder, OperationState &result, Value mask,
7234 Operation *maskableOp,
7235 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7236 assert(maskRegionBuilder &&
7237 "builder callback for 'maskRegion' must be present");
7238
7239 result.addOperands(mask);
7240 OpBuilder::InsertionGuard guard(builder);
7241 Region *maskRegion = result.addRegion();
7242 builder.createBlock(maskRegion);
7243 maskRegionBuilder(builder, maskableOp);
7244}
7245
7246void MaskOp::build(
7247 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7248 Value mask, Operation *maskableOp,
7249 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7250 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
7251 maskRegionBuilder);
7252}
7253
7254void MaskOp::build(
7255 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7256 Value mask, Value passthru, Operation *maskableOp,
7257 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7258 build(builder, result, mask, maskableOp, maskRegionBuilder);
7259 if (passthru)
7260 result.addOperands(passthru);
7261 result.addTypes(resultTypes);
7262}
7263
7264ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7265 // Create the op region.
7266 result.regions.reserve(1);
7267 Region &maskRegion = *result.addRegion();
7268
7269 auto &builder = parser.getBuilder();
7270
7271 // Parse all the operands.
7272 OpAsmParser::UnresolvedOperand mask;
7273 if (parser.parseOperand(mask))
7274 return failure();
7275
7276 // Optional passthru operand.
7277 OpAsmParser::UnresolvedOperand passthru;
7278 ParseResult parsePassthru = parser.parseOptionalComma();
7279 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7280 return failure();
7281
7282 // Parse op region.
7283 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7284 return failure();
7285
7286 MaskOp::ensureTerminator(maskRegion, builder, result.location);
7287
7288 // Parse the optional attribute list.
7289 if (parser.parseOptionalAttrDict(result.attributes))
7290 return failure();
7291
7292 // Parse all the types.
7293 Type maskType;
7294 if (parser.parseColonType(maskType))
7295 return failure();
7296
7297 SmallVector<Type> resultTypes;
7298 if (parser.parseOptionalArrowTypeList(resultTypes))
7299 return failure();
7300 result.types.append(resultTypes);
7301
7302 // Resolve operands.
7303 if (parser.resolveOperand(mask, maskType, result.operands))
7304 return failure();
7305
7306 if (parsePassthru.succeeded()) {
7307 if (resultTypes.empty())
7308 return parser.emitError(
7309 parser.getNameLoc(),
7310 "expects a result if passthru operand is provided");
7311
7312 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7313 return failure();
7314 }
7315
7316 return success();
7317}
7318
7319void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7320 p << " " << getMask();
7321 if (getPassthru())
7322 p << ", " << getPassthru();
7323
7324 // Print single masked operation and skip terminator.
7325 p << " { ";
7326 Block *singleBlock = &getMaskRegion().getBlocks().front();
7327 if (singleBlock && !singleBlock->getOperations().empty())
7328 p.printCustomOrGenericOp(&singleBlock->front());
7329 p << " }";
7330
7331 p.printOptionalAttrDict(getOperation()->getAttrs());
7332
7333 p << " : " << getMask().getType();
7334 if (getNumResults() > 0)
7335 p << " -> " << getResultTypes();
7336}
7337
7338void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7339 // 1. For an empty `vector.mask`, create a default terminator.
7340 if (region.empty() || region.front().empty()) {
7341 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7342 MaskOp>::ensureTerminator(region, builder, loc);
7343 return;
7344 }
7345
7346 // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7347 Block &block = region.front();
7348 if (isa<vector::YieldOp>(block.back()))
7349 return;
7350
7351 // 3. For a non-empty `vector.mask` without an explicit terminator:
7352
7353 // Create default terminator if the number of masked operations is not
7354 // one. This case will trigger a verification failure.
7355 if (block.getOperations().size() != 1) {
7356 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7357 MaskOp>::ensureTerminator(region, builder, loc);
7358 return;
7359 }
7360
7361 // Create a terminator that yields the results from the masked operation.
7362 OpBuilder opBuilder(builder.getContext());
7363 Operation *maskedOp = &block.front();
7364 opBuilder.setInsertionPointToEnd(&block);
7365 vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7366}
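// For illustration (names assumed, not from this file): when the region holds
// exactly one masked operation and no terminator, e.g.
//
//   %r = vector.reduction <add>, %v : vector<8xf32> into f32
//
// case 3 above appends the terminator
//
//   vector.yield %r : f32
//
// so the region's yielded values line up with the masked operation's results.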
7367
7368LogicalResult MaskOp::verify() {
7369 // Structural checks.
7370 Block &block = getMaskRegion().getBlocks().front();
7371 if (block.getOperations().empty())
7372 return emitOpError("expects a terminator within the mask region");
7373
7374 unsigned numMaskRegionOps = block.getOperations().size();
7375 if (numMaskRegionOps > 2)
7376 return emitOpError("expects only one operation to mask");
7377
7378 // Terminator checks.
7379 auto terminator = dyn_cast<vector::YieldOp>(block.back());
7380 if (!terminator)
7381 return emitOpError("expects a terminator within the mask region");
7382
7383 if (terminator->getNumOperands() != getNumResults())
7384 return emitOpError(
7385 "expects number of results to match mask region yielded values");
7386
7387 // Empty vector.mask. Nothing else to check.
7388 if (numMaskRegionOps == 1)
7389 return success();
7390
7391 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7392 if (!maskableOp)
7393 return emitOpError("expects a MaskableOpInterface within the mask region");
7394
7395 // Result checks.
7396 if (maskableOp->getNumResults() != getNumResults())
7397 return emitOpError("expects number of results to match maskable operation "
7398 "number of results");
7399
7400 if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7401 return emitOpError("expects all the results from the MaskableOpInterface "
7402 "to match all the values returned by the terminator");
7403
7404 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7405 return emitOpError(
7406 "expects result type to match maskable operation result type");
7407
7408 if (llvm::count_if(maskableOp->getResultTypes(),
7409 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7410 return emitOpError("multiple vector results not supported");
7411
7412 // Mask checks.
7413 Type expectedMaskType = maskableOp.getExpectedMaskType();
7414 if (getMask().getType() != expectedMaskType)
7415 return emitOpError("expects a ")
7416 << expectedMaskType << " mask for the maskable operation";
7417
7418 // Passthru checks.
7419 Value passthru = getPassthru();
7420 if (passthru) {
7421 if (!maskableOp.supportsPassthru())
7422 return emitOpError(
7423 "doesn't expect a passthru argument for this maskable operation");
7424
7425 if (maskableOp->getNumResults() != 1)
7426 return emitOpError("expects result when passthru argument is provided");
7427
7428 if (passthru.getType() != maskableOp->getResultTypes()[0])
7429 return emitOpError("expects passthru type to match result type");
7430 }
7431
7432 return success();
7433}
7434
7435/// Folds an empty `vector.mask` with no passthru operand, with or without
7436/// return values. For example:
7437///
7438/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7439/// vector<8xi1> -> vector<8xf32>
7440/// %1 = user_op %0 : vector<8xf32>
7441///
7442/// becomes:
7443///
7444/// %0 = user_op %a : vector<8xf32>
7445///
7446/// An empty `vector.mask` with a passthru operand is handled by the
7447/// canonicalizer instead, as that requires creating new operations.
7448
7449static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7450 SmallVectorImpl<OpFoldResult> &results) {
7451 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7452 return failure();
7453
7454 Block *block = maskOp.getMaskBlock();
7455 auto terminator = cast<vector::YieldOp>(block->front());
7456 if (terminator.getNumOperands() == 0) {
7457 // `vector.mask` has no results, just remove the `vector.mask`.
7458 return success();
7459 }
7460
7461 // `vector.mask` has results, propagate the results.
7462 llvm::append_range(results, terminator.getOperands());
7463 return success();
7464}
7465
7466LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7467 SmallVectorImpl<OpFoldResult> &results) {
7468 if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7469 return success();
7470
7471 MaskFormat maskFormat = getMaskFormat(getMask());
7472 if (maskFormat != MaskFormat::AllTrue)
7473 return failure();
7474
7475 // Move maskable operation outside of the `vector.mask` region.
7476 Operation *maskableOp = getMaskableOp();
7477 maskableOp->dropAllUses();
7478 maskableOp->moveBefore(getOperation());
7479
7480 llvm::append_range(results, maskableOp->getResults());
7481 return success();
7482}
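// A sketch of the all-true fold above (illustrative values): given
//
//   %m = vector.constant_mask [8] : vector<8xi1>
//   %0 = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
//        : vector<8xi1> -> f32
//
// the reduction is hoisted out of the region and %0 is replaced by its result,
// since getMaskFormat classifies %m as MaskFormat::AllTrue.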
7483
7484/// Canonicalize empty `vector.mask` operations that can't be handled in
7485/// `MaskOp::fold` as they require creating new operations.
7486///
7487/// Example 1: Empty `vector.mask` with passthru operand.
7488///
7489/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
7490/// vector<8xi1> -> vector<8xf32>
7491///
7492/// becomes:
7493///
7494/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
7495///
7496class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
7497 using Base::Base;
7498
7499 LogicalResult matchAndRewrite(MaskOp maskOp,
7500 PatternRewriter &rewriter) const override {
7501 if (!maskOp.isEmpty())
7502 return failure();
7503
7504 if (!maskOp.hasPassthru())
7505 return failure();
7506
7507 Block *block = maskOp.getMaskBlock();
7508 auto terminator = cast<vector::YieldOp>(block->front());
7509 assert(terminator.getNumOperands() == 1 &&
7510 "expected one result when passthru is provided");
7511
7512 rewriter.replaceOpWithNewOp<arith::SelectOp>(
7513 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
7514 terminator.getOperand(0), maskOp.getPassthru());
7515
7516 return success();
7517 }
7518};
7519
7520void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7521 MLIRContext *context) {
7522 results.add<CanonializeEmptyMaskOp>(context);
7523}
7524
7525// MaskingOpInterface definitions.
7526
7527/// Returns the operation masked by this 'vector.mask'.
7528Operation *MaskOp::getMaskableOp() {
7529 Block *block = getMaskBlock();
7530 if (block->getOperations().size() < 2)
7531 return nullptr;
7532
7533 return &block->front();
7534}
7535
7536/// Returns true if 'vector.mask' has a passthru value.
7537bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
7538
7539//===----------------------------------------------------------------------===//
7540// ScanOp
7541//===----------------------------------------------------------------------===//
7542
7543LogicalResult ScanOp::verify() {
7544 VectorType srcType = getSourceType();
7545 VectorType initialType = getInitialValueType();
7546 // Check reduction dimension < rank.
7547 int64_t srcRank = srcType.getRank();
7548 int64_t reductionDim = getReductionDim();
7549 if (reductionDim >= srcRank)
7550 return emitOpError("reduction dimension ")
7551 << reductionDim << " has to be less than " << srcRank;
7552
7553 // Check that rank(initial_value) = rank(src) - 1.
7554 int64_t initialValueRank = initialType.getRank();
7555 if (initialValueRank != srcRank - 1)
7556 return emitOpError("initial value rank ")
7557 << initialValueRank << " has to be equal to " << srcRank - 1;
7558
7559 // Check shapes of initial value and src.
7560 ArrayRef<int64_t> srcShape = srcType.getShape();
7561 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
7562 SmallVector<int64_t> expectedShape;
7563 for (int i = 0; i < srcRank; i++) {
7564 if (i != reductionDim)
7565 expectedShape.push_back(srcShape[i]);
7566 }
7567 if (!llvm::equal(initialValueShapes, expectedShape)) {
7568 return emitOpError("incompatible input/initial value shapes");
7569 }
7570
7571 // Verify supported reduction kind.
7572 Type eltType = getDestType().getElementType();
7573 if (!isSupportedCombiningKind(getKind(), eltType))
7574 return emitOpError("unsupported reduction type ")
7575 << eltType << " for kind '" << stringifyCombiningKind(getKind())
7576 << "'";
7577
7578 return success();
7579}
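// An example satisfying the checks above (illustrative shapes, not from this
// file): scanning over dimension 1 of a vector<4x8x16x32xf32> requires a
// vector<4x16x32xf32> initial value, i.e. the source shape with the reduction
// dimension dropped:
//
//   %res:2 = vector.scan <add>, %src, %init
//            {inclusive = true, reduction_dim = 1}
//            : vector<4x8x16x32xf32>, vector<4x16x32xf32>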
7580
7581void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
7582 RewritePatternSet &patterns, PatternBenefit benefit) {
7583 patterns
7584 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
7585 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
7586 StridedSliceConstantMaskFolder, TransposeFolder>(
7587 patterns.getContext(), benefit);
7588}
7589
7590Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
7591 CombiningKind kind, Value v1, Value acc,
7592 arith::FastMathFlagsAttr fastmath,
7593 Value mask) {
7594 Type t1 = getElementTypeOrSelf(v1.getType());
7595 Type tAcc = getElementTypeOrSelf(acc.getType());
7596 Value result;
7597
7598 switch (kind) {
7599 case CombiningKind::ADD:
7600 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7601 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
7602 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7603 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
7604 else
7605 llvm_unreachable("invalid value types for ADD reduction");
7606 break;
7607 case CombiningKind::AND:
7608 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7609 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
7610 break;
7611 case CombiningKind::MAXNUMF:
7612 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7613 "expected float values");
7614 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
7615 break;
7616 case CombiningKind::MAXIMUMF:
7617 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7618 "expected float values");
7619 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
7620 break;
7621 case CombiningKind::MINNUMF:
7622 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7623 "expected float values");
7624 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
7625 break;
7626 case CombiningKind::MINIMUMF:
7627 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
7628 "expected float values");
7629 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
7630 break;
7631 case CombiningKind::MAXSI:
7632 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7633 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
7634 break;
7635 case CombiningKind::MINSI:
7636 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7637 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
7638 break;
7639 case CombiningKind::MAXUI:
7640 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7641 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
7642 break;
7643 case CombiningKind::MINUI:
7644 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7645 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
7646 break;
7647 case CombiningKind::MUL:
7648 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
7649 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
7650 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
7651 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
7652 else
7653 llvm_unreachable("invalid value types for MUL reduction");
7654 break;
7655 case CombiningKind::OR:
7656 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7657 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
7658 break;
7659 case CombiningKind::XOR:
7660 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
7661 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
7662 break;
7663 };
7664
7665 assert(result && "unknown CombiningKind");
7666 return selectPassthru(b, mask, result, acc);
7667}
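// Usage sketch (assumed values: `b`, `loc`, `lhs`, `acc` and `mask` are in
// scope; this is not code from this file):
//
//   Value sum = makeArithReduction(b, loc, CombiningKind::ADD, lhs, acc,
//                                  /*fastmath=*/nullptr, mask);
//
// For float operands this emits arith.addf; with a non-null `mask` the result
// is additionally routed through selectPassthru below, i.e. an arith.select
// between the combined value and `acc`.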
7668
7669//===----------------------------------------------------------------------===//
7670// StepOp
7671//===----------------------------------------------------------------------===//
7672
7673void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7674 SetIntRangeFn setResultRanges) {
7675 auto resultType = cast<VectorType>(getType());
7676 if (resultType.isScalable()) {
7677 return;
7678 }
7679 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
7680 APInt zero(bitwidth, 0);
7681 APInt high(bitwidth, resultType.getDimSize(0) - 1);
7682 ConstantIntRanges result = {zero, high, zero, high};
7683 setResultRanges(getResult(), result);
7684}
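// For example, for `vector.step : vector<3xindex>` the inferred range is
// [0, 2] for both the signed and unsigned bounds; scalable result types are
// skipped above because their upper bound is not known statically.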
7685
7686namespace {
7687
7688/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
7689/// constant large enough such that the result is the same at all indices.
7690///
7691/// For example, rewrite the 'greater than' comparison below,
7692///
7693/// ```mlir
7694/// %cst = arith.constant dense<7> : vector<3xindex>
7695/// %stp = vector.step : vector<3xindex>
7696/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
7697/// ```
7698///
7699/// as,
7700///
7701/// ```mlir
7702/// %out = arith.constant dense<false> : vector<3xi1>.
7703/// ```
7704///
7705/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
7706/// is false at ALL indices we fold. If the constant were 1, then
7707/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
7708/// conservatively preferring the 'compact' vector.step representation.
7709///
7710/// Note: this folder only works for the case where the constant (`%cst` above)
7711/// is the second operand of the comparison. The arith.cmpi canonicalizer will
7712/// ensure that constants are always second (on the right).
7713struct StepCompareFolder : public OpRewritePattern<StepOp> {
7714 using Base::Base;
7715
7716 LogicalResult matchAndRewrite(StepOp stepOp,
7717 PatternRewriter &rewriter) const override {
7718 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
7719
7720 for (OpOperand &use : stepOp.getResult().getUses()) {
7721 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
7722 if (!cmpiOp)
7723 continue;
7724
7725 // arith.cmpi canonicalizer makes constants final operands.
7726 const unsigned stepOperandNumber = use.getOperandNumber();
7727 if (stepOperandNumber != 0)
7728 continue;
7729
7730 // Check that operand 1 is a constant.
7731 unsigned constOperandNumber = 1;
7732 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7733 std::optional<int64_t> maybeConstValue =
7734 getConstantIntValue(otherOperand);
7735 if (!maybeConstValue.has_value())
7736 continue;
7737
7738 int64_t constValue = maybeConstValue.value();
7739 arith::CmpIPredicate pred = cmpiOp.getPredicate();
7740
7741 auto maybeSplat = [&]() -> std::optional<bool> {
7742 // Handle ult (unsigned less than) and uge (unsigned greater equal).
7743 if ((pred == arith::CmpIPredicate::ult ||
7744 pred == arith::CmpIPredicate::uge) &&
7745 stepSize <= constValue)
7746 return pred == arith::CmpIPredicate::ult;
7747
7748 // Handle ule and ugt.
7749 if ((pred == arith::CmpIPredicate::ule ||
7750 pred == arith::CmpIPredicate::ugt) &&
7751 stepSize - 1 <= constValue) {
7752 return pred == arith::CmpIPredicate::ule;
7753 }
7754
7755 // Handle eq and ne.
7756 if ((pred == arith::CmpIPredicate::eq ||
7757 pred == arith::CmpIPredicate::ne) &&
7758 stepSize <= constValue)
7759 return pred == arith::CmpIPredicate::ne;
7760
7761 return std::nullopt;
7762 }();
7763
7764 if (!maybeSplat.has_value())
7765 continue;
7766
7767 rewriter.setInsertionPointAfter(cmpiOp);
7768
7769 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7770 if (!type)
7771 continue;
7772
7773 auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
7774 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7775 type, boolAttr);
7776
7777 rewriter.replaceOp(cmpiOp, splat);
7778 return success();
7779 }
7780
7781 return failure();
7782 }
7783};
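// Worked instance of the ule case above (illustrative): for a vector<3xindex>
// step, stepSize - 1 == 2, so `arith.cmpi ule, %stp, %cst` with a constant
// splat >= 2 is replaced by a dense<true> vector<3xi1> constant, since
// [0, 1, 2] <= [2, 2, 2] holds at every index.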
7784} // namespace
7785
7786void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
7787 MLIRContext *context) {
7788 results.add<StepCompareFolder>(context);
7789}
7790
7791//===----------------------------------------------------------------------===//
7792// Vector Masking Utilities
7793//===----------------------------------------------------------------------===//
7794
7795/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
7796/// as the masked operation.
7797void mlir::vector::createMaskOpRegion(OpBuilder &builder,
7798 Operation *maskableOp) {
7799 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
7800 Block *insBlock = builder.getInsertionBlock();
7801 // Create a block and move the op to that block.
7802 insBlock->getOperations().splice(
7803 insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
7804 YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
7805}
7806
7807/// Creates a vector.mask operation around a maskable operation. Returns the
7808/// vector.mask operation if a non-null mask is provided. Otherwise, returns
7809/// the maskable operation itself.
7810Operation *mlir::vector::maskOperation(OpBuilder &builder,
7811 Operation *maskableOp, Value mask,
7812 Value passthru) {
7813 if (!mask)
7814 return maskableOp;
7815 if (passthru)
7816 return MaskOp::create(builder, maskableOp->getLoc(),
7817 maskableOp->getResultTypes(), mask, passthru,
7818 maskableOp, createMaskOpRegion);
7819 return MaskOp::create(builder, maskableOp->getLoc(),
7820 maskableOp->getResultTypes(), mask, maskableOp,
7821 createMaskOpRegion);
7822}
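// Usage sketch (assumed values: `builder`, `readOp` and `mask` are in scope;
// not code from this file):
//
//   Operation *masked = maskOperation(builder, readOp, mask);
//
// With a non-null `mask` this wraps `readOp` in a `vector.mask` whose region
// is built by createMaskOpRegion above and returns the new vector.mask op;
// with a null `mask` it returns `readOp` unchanged.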
7823
7824/// Creates a vector select operation that picks values from `newValue` or
7825/// `passthru` for each result vector lane based on `mask`. This utility is used
7826/// to propagate the pass-thru value of vector.mask, or for cases where only
7827/// pass-thru value propagation is needed. VP intrinsics do not support
7828/// pass-thru values and every masked-out lane is set to poison. LLVM backends
7829/// are usually able to match op + select patterns and fold them into native
7830/// target instructions.
7831Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
7832 Value newValue, Value passthru) {
7833 if (!mask)
7834 return newValue;
7835
7836 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
7837 mask, newValue, passthru);
7838}
7839
7840//===----------------------------------------------------------------------===//
7841// TableGen'd op method definitions
7842//===----------------------------------------------------------------------===//
7843
7844#define GET_ATTRDEF_CLASSES
7845#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
7846
7847#define GET_OP_CLASSES
7848#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"