MLIR 23.0.0git
VectorOps.cpp
Go to the documentation of this file.
1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
15
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Builders.h"
34#include "mlir/IR/IRMapping.h"
38#include "mlir/IR/ValueRange.h"
41#include "mlir/Support/LLVM.h"
43#include "llvm/ADT/ArrayRef.h"
44#include "llvm/ADT/Repeated.h"
45#include "llvm/ADT/STLExtras.h"
46#include "llvm/ADT/SmallVector.h"
47#include "llvm/ADT/SmallVectorExtras.h"
48#include "llvm/ADT/StringSet.h"
49#include "llvm/ADT/TypeSwitch.h"
50#include "llvm/Support/Casting.h"
51
52#include <cassert>
53#include <cstdint>
54#include <numeric>
55
56#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
57// Pull in all enum type and utility function definitions.
58#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
59
60using namespace mlir;
61using namespace mlir::vector;
62
63/// Helper enum to classify mask value.
64enum class MaskFormat {
68};
69
70/// Helper method to classify a mask value. Currently, the method
71/// looks "under the hood" of a constant value with dense attributes
72/// and a constant mask operation (since the client may be called at
73/// various stages during progressive lowering).
75 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
76 // Inspect constant dense values. We count up for bits that
77 // are set, count down for bits that are cleared, and bail
78 // when a mix is detected.
79 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
80 int64_t val = 0;
81 for (bool b : denseElts.getValues<bool>())
82 if (b && val >= 0)
83 val++;
84 else if (!b && val <= 0)
85 val--;
86 else
88 if (val > 0)
90 if (val < 0)
92 }
93 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
94 // Inspect constant mask index. If the index exceeds the
95 // dimension size, all bits are set. If the index is zero
96 // or less, no bits are set.
97 ArrayRef<int64_t> masks = m.getMaskDimSizes();
98 auto shape = m.getType().getShape();
99 bool allTrue = true;
100 bool allFalse = true;
101 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
102 if (maskIdx < dimSize)
103 allTrue = false;
104 if (maskIdx > 0)
105 allFalse = false;
106 }
107 if (allTrue)
108 return MaskFormat::AllTrue;
109 if (allFalse)
111 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
112 // Finds all-false create_masks. An all-true create_mask requires all
113 // dims to be constants, so that'll be folded to a constant_mask, then
114 // detected in the constant_mask case.
115 auto maskOperands = m.getOperands();
116 for (Value operand : maskOperands) {
117 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
118 int64_t dimSize =
119 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
120 if (dimSize <= 0)
122 }
123 }
124 return MaskFormat::Unknown;
125 }
126 return MaskFormat::Unknown;
127}
128
129/// Default callback to build a region with a 'vector.yield' terminator with no
130/// arguments.
132 vector::YieldOp::create(builder, loc);
133}
134
135// Helper for verifying combining kinds in contractions and reductions.
136static bool isSupportedCombiningKind(CombiningKind combiningKind,
137 Type elementType) {
138 switch (combiningKind) {
139 case CombiningKind::ADD:
140 case CombiningKind::MUL:
141 return elementType.isIntOrIndexOrFloat();
142 case CombiningKind::MINUI:
143 case CombiningKind::MINSI:
144 case CombiningKind::MAXUI:
145 case CombiningKind::MAXSI:
146 case CombiningKind::AND:
147 case CombiningKind::OR:
148 case CombiningKind::XOR:
149 return elementType.isIntOrIndex();
150 case CombiningKind::MINNUMF:
151 case CombiningKind::MAXNUMF:
152 case CombiningKind::MINIMUMF:
153 case CombiningKind::MAXIMUMF:
154 return llvm::isa<FloatType>(elementType);
155 }
156 return false;
157}
158
159/// Returns the effective rank of the vector to read/write for Xfer Ops
160///
161/// When the element type of the shaped type is _a scalar_, this will simply
162/// return the rank of the vector ( the result for xfer_read or the value to
163/// store for xfer_write).
164///
165/// When the element type of the base shaped type is _a vector_, returns the
166/// difference between the original vector type and the element type of the
167/// shaped type.
168///
169/// EXAMPLE 1 (element type is _a scalar_):
170/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
171/// - shapedType.getElementType() = f32 (rank 0)
172/// - vectorType.getRank() = 2
173/// - Result = 2 - 0 = 2
174///
175/// EXAMPLE 2 (element type is _a vector_):
176/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
177/// - shapedType.getElementType() = vector<20xf32> (rank 1)
178/// - vectorType.getRank() = 1
179/// - Result = 1 - 1 = 0
180///
181/// This is used to determine the number of minor dimensions for identity maps
182/// in vector transfer Ops.
183static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
184 VectorType vectorType) {
185 unsigned elementVectorRank = 0;
186 VectorType elementVectorType =
187 llvm::dyn_cast<VectorType>(shapedType.getElementType());
188 if (elementVectorType)
189 elementVectorRank += elementVectorType.getRank();
190 return vectorType.getRank() - elementVectorRank;
191}
192
194 VectorType vectorType) {
195 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
196 // TODO: replace once we have 0-d vectors.
197 if (shapedType.getRank() == 0 &&
198 vectorType.getShape() == ArrayRef<int64_t>{1})
199 return AffineMap::get(
200 /*numDims=*/0, /*numSymbols=*/0,
201 getAffineConstantExpr(0, shapedType.getContext()));
203 shapedType.getRank(),
204 getEffectiveVectorRankForXferOp(shapedType, vectorType),
205 shapedType.getContext());
206}
207
208/// Check if `write` is of a constant splat and the masked `read` is padded with
209/// the same splat value -- meaning it could be the same value as the initial
210/// constant splat.
211static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
212 vector::TransferReadOp read) {
213 auto readMask = read.getMask();
214 auto writeMask = write.getMask();
215 // Check if the masks are consistent. The splat value could be the same if the
216 // read is masked (and padded with the splat value), and the write is unmasked
217 // or has the same mask. Note this does not allow the case where the write is
218 // masked and the read is unmasked, as then the read could be of more elements
219 // than the write (which may not be the same value).
220 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
221 if (!couldBeSameSplat)
222 return false;
223 // Check for constant splat (as the source of the write).
224 DenseElementsAttr splatAttr;
225 if (!matchPattern(write.getVector(),
226 m_Constant<DenseElementsAttr>(&splatAttr)) ||
227 !splatAttr.isSplat()) {
228 return false;
229 }
230 // The padding of the read and the constant splat value must be the same.
231 Attribute padAttr;
232 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
233 return false;
234 return padAttr == splatAttr.getSplatValue<Attribute>();
235}
236
237bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
238 vector::TransferReadOp read) {
239 return !defWrite.hasOutOfBoundsDim() &&
240 defWrite.getIndices() == read.getIndices() &&
241 defWrite.getVectorType() == read.getVectorType() &&
242 defWrite.getPermutationMap() == read.getPermutationMap() &&
243 ((!defWrite.getMask() && !read.getMask()) ||
245}
246
247bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
248 vector::TransferWriteOp priorWrite) {
249 return priorWrite.getIndices() == write.getIndices() &&
250 priorWrite.getMask() == write.getMask() &&
251 priorWrite.getVectorType() == write.getVectorType() &&
252 priorWrite.getPermutationMap() == write.getPermutationMap();
253}
254
256 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
257 bool testDynamicValueUsingBounds) {
258 // For simplicity only look at transfer of same type.
259 if (transferA.getVectorType() != transferB.getVectorType())
260 return false;
261 unsigned rankOffset = transferA.getLeadingShapedRank();
262 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
263 Value indexA = transferA.getIndices()[i];
264 Value indexB = transferB.getIndices()[i];
265 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
266 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
267
268 if (i < rankOffset) {
269 // For leading dimensions, if we can prove that index are different we
270 // know we are accessing disjoint slices.
271 if (cstIndexA.has_value() && cstIndexB.has_value()) {
272 if (*cstIndexA != *cstIndexB)
273 return true;
274 continue;
275 }
276 if (testDynamicValueUsingBounds) {
277 // First try to see if we can fully compose and simplify the affine
278 // expression as a fast track.
279 FailureOr<uint64_t> delta =
281 if (succeeded(delta) && *delta != 0)
282 return true;
283
284 FailureOr<bool> testEqual =
286 if (succeeded(testEqual) && !testEqual.value())
287 return true;
288 }
289 } else {
290 // For this dimension, we slice a part of the memref we need to make sure
291 // the intervals accessed don't overlap.
292 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
293 if (cstIndexA.has_value() && cstIndexB.has_value()) {
294 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
295 if (distance >= vectorDim)
296 return true;
297 continue;
298 }
299 if (testDynamicValueUsingBounds) {
300 // First try to see if we can fully compose and simplify the affine
301 // expression as a fast track.
302 FailureOr<int64_t> delta =
304 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
305 return true;
306
307 FailureOr<int64_t> computeDelta =
309 if (succeeded(computeDelta)) {
310 if (std::abs(computeDelta.value()) >= vectorDim)
311 return true;
312 }
313 }
314 }
315 }
316 return false;
317}
318
319bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
320 VectorTransferOpInterface transferB,
321 bool testDynamicValueUsingBounds) {
322 if (transferA.getBase() != transferB.getBase())
323 return false;
324 return isDisjointTransferIndices(transferA, transferB,
325 testDynamicValueUsingBounds);
326}
327
328// Helper to iterate over n-D vector slice elements. Calculate the next
329// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
330// Modifies the `position` in place. Returns a failure when `position` becomes
331// the end position.
332static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
334 ArrayRef<int64_t> offsets) {
335 for (auto [posInDim, dimSize, offsetInDim] :
336 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
337 ++posInDim;
338 if (posInDim < dimSize + offsetInDim)
339 return success();
340
341 // Carry the overflow to the next loop iteration.
342 posInDim = offsetInDim;
343 }
344
345 return failure();
346}
347
348/// Returns the integer numbers in `values`. `values` are expected to be
349/// constant operations.
352 llvm::transform(values, std::back_inserter(ints), [](Value value) {
353 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
354 assert(constOp && "Unexpected non-constant index");
355 return constOp.value();
356 });
357 return ints;
358}
359
360/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
361/// be constant operations.
364 llvm::transform(
365 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
366 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
367 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
368 });
369 return ints;
370}
371
372/// Convert `foldResults` into Values. Integer attributes are converted to
373/// constant op.
375 ArrayRef<OpFoldResult> foldResults) {
376 SmallVector<Value> values;
377 llvm::transform(foldResults, std::back_inserter(values),
378 [&](OpFoldResult foldResult) {
379 if (auto attr = dyn_cast<Attribute>(foldResult))
381 builder, loc, cast<IntegerAttr>(attr).getInt())
382 .getResult();
383
384 return cast<Value>(foldResult);
385 });
386 return values;
387}
388
389std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
390 if (value.getDefiningOp<vector::VectorScaleOp>())
391 return 1;
392 auto mul = value.getDefiningOp<arith::MulIOp>();
393 if (!mul)
394 return {};
395 auto lhs = mul.getLhs();
396 auto rhs = mul.getRhs();
397 if (lhs.getDefiningOp<vector::VectorScaleOp>())
398 return getConstantIntValue(rhs);
399 if (rhs.getDefiningOp<vector::VectorScaleOp>())
400 return getConstantIntValue(lhs);
401 return {};
402}
403
404/// Converts numeric attributes to the expected type. Supports
405/// integer-to-integer and float-to-integer conversions. Returns the original
406/// attribute if no conversion is needed or supported.
407static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
408 // Integer-to-integer conversion
409 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
410 if (auto intType = dyn_cast<IntegerType>(expectedType)) {
411 if (intAttr.getType() != expectedType)
412 return IntegerAttr::get(expectedType, intAttr.getInt());
413 }
414 return attr;
415 }
416
417 // Float-to-integer bitcast (preserves bit representation)
418 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
419 auto intType = dyn_cast<IntegerType>(expectedType);
420 if (!intType)
421 return attr;
422
423 APFloat floatVal = floatAttr.getValue();
424 APInt intVal = floatVal.bitcastToAPInt();
425 return IntegerAttr::get(expectedType, intVal);
426 }
427
428 return attr;
429}
430
431//===----------------------------------------------------------------------===//
432// CombiningKindAttr
433//===----------------------------------------------------------------------===//
434
435namespace mlir {
436namespace vector {
437namespace detail {
439 using KeyTy = uint64_t;
440
442
443 bool operator==(const KeyTy &key) const { return value == key; }
444
446 const KeyTy &key) {
447 return new (allocator.allocate<BitmaskEnumStorage>())
449 }
450
452};
453} // namespace detail
454} // namespace vector
455} // namespace mlir
456
457//===----------------------------------------------------------------------===//
458// VectorDialect
459//===----------------------------------------------------------------------===//
460
461namespace {
462/// This class defines the interface for handling inlining with vector dialect
463/// operations.
464struct VectorInlinerInterface : public DialectInlinerInterface {
465 using DialectInlinerInterface::DialectInlinerInterface;
466
467 /// All vector dialect ops can be inlined.
468 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
469 return true;
470 }
471};
472} // namespace
473
474void VectorDialect::initialize() {
475 addAttributes<
476#define GET_ATTRDEF_LIST
477#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
478 >();
479
480 addOperations<
481#define GET_OP_LIST
482#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
483 >();
484
485 addInterfaces<VectorInlinerInterface>();
486
487 declarePromisedInterfaces<memref::IndexedAccessOpInterface, LoadOp, StoreOp,
488 MaskedLoadOp, MaskedStoreOp, ExpandLoadOp,
489 CompressStoreOp>();
490 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
491 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
492 YieldOp>();
493 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
494 TransferWriteOp>();
495 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
496 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
497 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
498}
499
500/// Materialize a single constant operation from a given attribute value with
501/// the desired resultant type.
502Operation *VectorDialect::materializeConstant(OpBuilder &builder,
503 Attribute value, Type type,
504 Location loc) {
505 if (matchPattern(value, ub::m_Poison()))
506 return value.getDialect().materializeConstant(builder, value, type, loc);
507
508 return arith::ConstantOp::materialize(builder, value, type, loc);
509}
510
512 return builder.getIntegerType(64);
513}
514
516 ArrayRef<int64_t> values) {
517 return builder.getI64ArrayAttr(values);
518}
519
520//===----------------------------------------------------------------------===//
521// MultiDimReductionOp
522//===----------------------------------------------------------------------===//
523
524void vector::MultiDimReductionOp::build(OpBuilder &builder,
525 OperationState &result, Value source,
526 Value acc, ArrayRef<bool> reductionMask,
527 CombiningKind kind) {
528 SmallVector<int64_t> reductionDims;
529 for (const auto &en : llvm::enumerate(reductionMask))
530 if (en.value())
531 reductionDims.push_back(en.index());
532 build(builder, result, kind, source, acc, reductionDims);
533}
534
535OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
536 // No reduction dims: this is a noop regardless of rank.
537 if (getReductionDims().empty())
538 return getSource();
539 return {};
540}
541
542std::optional<SmallVector<int64_t, 4>>
543MultiDimReductionOp::getShapeForUnroll() {
544 return llvm::to_vector<4>(getSourceVectorType().getShape());
545}
546
547LogicalResult MultiDimReductionOp::verify() {
548 SmallVector<int64_t> targetShape;
549 SmallVector<bool> scalableDims;
550 Type inferredReturnType;
551 auto sourceScalableDims = getSourceVectorType().getScalableDims();
552 for (auto [dimIdx, dimSize] :
553 llvm::enumerate(getSourceVectorType().getShape()))
554 if (!llvm::any_of(getReductionDims(),
555 [dimIdx = dimIdx](int64_t reductionDimIdx) {
556 return reductionDimIdx == static_cast<int64_t>(dimIdx);
557 })) {
558 targetShape.push_back(dimSize);
559 scalableDims.push_back(sourceScalableDims[dimIdx]);
560 }
561 // TODO: update to also allow 0-d vectors when available.
562 if (targetShape.empty())
563 inferredReturnType = getSourceVectorType().getElementType();
564 else
565 inferredReturnType = VectorType::get(
566 targetShape, getSourceVectorType().getElementType(), scalableDims);
567 if (getType() != inferredReturnType)
568 return emitOpError() << "destination type " << getType()
569 << " is incompatible with source type "
570 << getSourceVectorType();
571
572 return success();
573}
574
575/// Returns the mask type expected by this operation.
576Type MultiDimReductionOp::getExpectedMaskType() {
577 auto vecType = getSourceVectorType();
578 return VectorType::get(vecType.getShape(),
579 IntegerType::get(vecType.getContext(), /*width=*/1),
580 vecType.getScalableDims());
581}
582
583namespace {
584// Only unit dimensions that are being reduced are folded. If the dimension is
585// unit, but not reduced, it is not folded, thereby keeping the output type the
586// same. If not all dimensions which are reduced are of unit dimension, this
587// transformation does nothing. This is just a generalization of
588// ElideSingleElementReduction for ReduceOp.
589struct ElideUnitDimsInMultiDimReduction
590 : public OpRewritePattern<MultiDimReductionOp> {
591 using Base::Base;
592
593 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
594 PatternRewriter &rewriter) const override {
595 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
596 for (const auto &dim : enumerate(shape)) {
597 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
598 return failure();
599 }
600
601 // Vector mask setup.
602 OpBuilder::InsertionGuard guard(rewriter);
603 Operation *rootOp;
604 Value mask;
605 if (reductionOp.isMasked()) {
606 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
607 rootOp = reductionOp.getMaskingOp();
608 mask = reductionOp.getMaskingOp().getMask();
609 } else {
610 rootOp = reductionOp;
611 }
612
613 Location loc = reductionOp.getLoc();
614 Value acc = reductionOp.getAcc();
615 Value cast;
616 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
617 if (mask) {
618 VectorType newMaskType =
619 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
620 dstVecType.getScalableDims());
621 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
622 }
623 cast = vector::ShapeCastOp::create(
624 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
625 } else {
626 // This means we are reducing all the dimensions, and all reduction
627 // dimensions are of size 1. So a simple extraction would do.
628 if (mask)
629 mask = vector::ExtractOp::create(rewriter, loc, mask);
630 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
631 }
632
633 Value result =
634 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
635 cast, /*fastmath=*/nullptr, mask);
636 rewriter.replaceOp(rootOp, result);
637 return success();
638 }
639};
640} // namespace
641
642void MultiDimReductionOp::getCanonicalizationPatterns(
643 RewritePatternSet &results, MLIRContext *context) {
644 results.add<ElideUnitDimsInMultiDimReduction>(context);
645}
646
647//===----------------------------------------------------------------------===//
648// ReductionOp
649//===----------------------------------------------------------------------===//
650
651void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
652 CombiningKind kind, Value vector,
653 arith::FastMathFlags fastMathFlags) {
654 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
655}
656
657void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
658 CombiningKind kind, Value vector, Value acc,
659 arith::FastMathFlags fastMathFlags) {
660 build(builder, result,
661 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
662 acc, fastMathFlags);
663}
664
665LogicalResult ReductionOp::verify() {
666 // Verify for 0-D and 1-D vector.
667 int64_t rank = getSourceVectorType().getRank();
668 if (rank > 1)
669 return emitOpError("unsupported reduction rank: ") << rank;
670
671 // Verify supported reduction kind.
672 Type eltType = getDest().getType();
673 if (!isSupportedCombiningKind(getKind(), eltType))
674 return emitOpError("unsupported reduction type '")
675 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
676 << "'";
677
678 return success();
679}
680
681// MaskableOpInterface methods.
682
683/// Returns the mask type expected by this operation.
684Type ReductionOp::getExpectedMaskType() {
685 auto vecType = getSourceVectorType();
686 return VectorType::get(vecType.getShape(),
687 IntegerType::get(vecType.getContext(), /*width=*/1),
688 vecType.getScalableDims());
689}
690
692 OpBuilder &builder, Location loc,
693 Value vector) {
694 switch (op) {
695 case arith::AtomicRMWKind::addf:
696 case arith::AtomicRMWKind::addi:
697 return vector::ReductionOp::create(builder, vector.getLoc(),
698 CombiningKind::ADD, vector);
699 case arith::AtomicRMWKind::mulf:
700 case arith::AtomicRMWKind::muli:
701 return vector::ReductionOp::create(builder, vector.getLoc(),
702 CombiningKind::MUL, vector);
703 case arith::AtomicRMWKind::minimumf:
704 return vector::ReductionOp::create(builder, vector.getLoc(),
705 CombiningKind::MINIMUMF, vector);
706 case arith::AtomicRMWKind::mins:
707 return vector::ReductionOp::create(builder, vector.getLoc(),
708 CombiningKind::MINSI, vector);
709 case arith::AtomicRMWKind::minu:
710 return vector::ReductionOp::create(builder, vector.getLoc(),
711 CombiningKind::MINUI, vector);
712 case arith::AtomicRMWKind::maximumf:
713 return vector::ReductionOp::create(builder, vector.getLoc(),
714 CombiningKind::MAXIMUMF, vector);
715 case arith::AtomicRMWKind::maxs:
716 return vector::ReductionOp::create(builder, vector.getLoc(),
717 CombiningKind::MAXSI, vector);
718 case arith::AtomicRMWKind::maxu:
719 return vector::ReductionOp::create(builder, vector.getLoc(),
720 CombiningKind::MAXUI, vector);
721 case arith::AtomicRMWKind::andi:
722 return vector::ReductionOp::create(builder, vector.getLoc(),
723 CombiningKind::AND, vector);
724 case arith::AtomicRMWKind::ori:
725 return vector::ReductionOp::create(builder, vector.getLoc(),
726 CombiningKind::OR, vector);
727 case arith::AtomicRMWKind::minnumf:
728 return vector::ReductionOp::create(builder, vector.getLoc(),
729 CombiningKind::MINNUMF, vector);
730 case arith::AtomicRMWKind::maxnumf:
731 return vector::ReductionOp::create(builder, vector.getLoc(),
732 CombiningKind::MAXNUMF, vector);
733 case arith::AtomicRMWKind::xori:
734 return vector::ReductionOp::create(builder, vector.getLoc(),
735 CombiningKind::XOR, vector);
736 default:
737 (void)emitOptionalError(loc, "Reduction operation type not supported");
738 break;
739 }
740 return nullptr;
741}
742
743std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
744 return llvm::to_vector<4>(getSourceVectorType().getShape());
745}
746
747namespace {
748struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
749 using Base::Base;
750
751 LogicalResult matchAndRewrite(ReductionOp reductionOp,
752 PatternRewriter &rewriter) const override {
753 // Vector mask setup.
754 OpBuilder::InsertionGuard guard(rewriter);
755 auto maskableOp =
756 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
757 Operation *rootOp;
758 Value mask;
759 if (maskableOp.isMasked()) {
760 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
761 rootOp = maskableOp.getMaskingOp();
762 mask = maskableOp.getMaskingOp().getMask();
763 } else {
764 rootOp = reductionOp;
765 }
766
767 auto vectorType = reductionOp.getSourceVectorType();
768 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
769 return failure();
770
771 Location loc = reductionOp.getLoc();
772 if (mask)
773 mask = ExtractOp::create(rewriter, loc, mask);
774 Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
775
776 if (Value acc = reductionOp.getAcc())
777 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
778 result, acc,
779 reductionOp.getFastmathAttr(), mask);
780
781 rewriter.replaceOp(rootOp, result);
782 return success();
783 }
784};
785} // namespace
786
787void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
788 MLIRContext *context) {
789 results.add<ElideSingleElementReduction>(context);
790}
791
792//===----------------------------------------------------------------------===//
793// ContractionOp
794//===----------------------------------------------------------------------===//
795
796void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
798 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
799 ArrayRef<IteratorType> iteratorTypes) {
800 result.addOperands({lhs, rhs, acc});
801 result.addTypes(acc.getType());
802 result.addAttribute(
803 getIndexingMapsAttrName(result.name),
804 builder.getAffineMapArrayAttr(
805 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
806 result.addAttribute(
807 getIteratorTypesAttrName(result.name),
808 builder.getArrayAttr(llvm::map_to_vector(
809 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
810 return IteratorTypeAttr::get(builder.getContext(), t);
811 })));
812}
813
814void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
816 ArrayAttr indexingMaps,
817 ArrayAttr iteratorTypes) {
818 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
819 ContractionOp::getDefaultKind());
820}
821
822void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
824 ArrayAttr indexingMaps,
825 ArrayAttr iteratorTypes, CombiningKind kind,
826 arith::FastMathFlags fastMathFlags) {
827 result.addOperands({lhs, rhs, acc});
828 result.addTypes(acc.getType());
829 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
830 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
831 result.addAttribute(getKindAttrName(result.name),
832 CombiningKindAttr::get(builder.getContext(), kind));
833 if (fastMathFlags != arith::FastMathFlags::none)
834 result.addAttribute(
835 getFastmathAttrName(result.name),
836 arith::FastMathFlagsAttr::get(builder.getContext(), fastMathFlags));
837}
838
839ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
845 Type resultType;
846 auto loc = parser.getCurrentLocation();
847 DictionaryAttr dictAttr;
848 // TODO: Unify linalg op attribute parsing.
849 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
850 parser.parseComma() || parser.parseOperand(rhsInfo) ||
851 parser.parseComma() || parser.parseOperand(accInfo) ||
852 parser.parseTrailingOperandList(masksInfo) ||
853 parser.parseOptionalAttrDict(result.attributes) ||
854 parser.parseColonTypeList(types) ||
855 parser.parseKeywordType("into", resultType) ||
856 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
857 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
858 parser.resolveOperand(accInfo, resultType, result.operands) ||
859 parser.addTypeToList(resultType, result.types))
860 return failure();
861 result.attributes.append(dictAttr.getValue().begin(),
862 dictAttr.getValue().end());
863
864 // Convert array of string into an array of IteratyType enums. This is needed,
865 // because tests still use the old format when 'iterator_types' attribute is
866 // represented as an array of strings.
867 // TODO: Remove this conversion once tests are fixed.
868 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
869 result.attributes.get(getIteratorTypesAttrName(result.name)));
870 if (!iteratorTypes) {
871 return parser.emitError(loc)
872 << "expected " << getIteratorTypesAttrName(result.name)
873 << " array attribute";
874 }
875
876 SmallVector<Attribute> iteratorTypeAttrs;
877
878 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
879 auto maybeIteratorType = symbolizeIteratorType(s);
880 if (!maybeIteratorType.has_value())
881 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
882
883 iteratorTypeAttrs.push_back(
884 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
885 }
886 result.attributes.set(getIteratorTypesAttrName(result.name),
887 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
888
889 if (!result.attributes.get(getKindAttrName(result.name))) {
890 result.addAttribute(
891 getKindAttrName(result.name),
892 CombiningKindAttr::get(result.getContext(),
893 ContractionOp::getDefaultKind()));
894 }
895 if (masksInfo.empty())
896 return success();
897 if (masksInfo.size() != 2)
898 return parser.emitError(parser.getNameLoc(),
899 "expected zero or exactly 2 vector mask operands");
900 auto lhsType = llvm::cast<VectorType>(types[0]);
901 auto rhsType = llvm::cast<VectorType>(types[1]);
902 auto maskElementType = parser.getBuilder().getI1Type();
903 std::array<VectorType, 2> maskTypes = {
904 VectorType::Builder(lhsType).setElementType(maskElementType),
905 VectorType::Builder(rhsType).setElementType(maskElementType)};
906 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
907 return failure();
908 return success();
909}
910
/// Prints vector.contract as
///   `{trait-attrs} %lhs, %rhs, %acc attr-dict : lhs-type, rhs-type into
///   result-type`.
/// The leading dictionary holds the attributes named by getTraitAttrNames();
/// all other attributes are printed by printOptionalAttrDict, which elides the
/// trait names.
911void ContractionOp::print(OpAsmPrinter &p) {
912 // TODO: Unify printing code with linalg ops.
913 auto attrNames = getTraitAttrNames();
914 llvm::StringSet<> traitAttrsSet;
915 traitAttrsSet.insert_range(attrNames);
// NOTE(review): `attrs` (the trait attributes collected for the leading
// dictionary) is declared on a line not shown here; presumably a
// SmallVector<NamedAttribute> — confirm.
917 for (auto attr : (*this)->getAttrs()) {
918 if (attr.getName() == getIteratorTypesAttrName()) {
919 auto iteratorTypes =
920 llvm::cast<ArrayAttr>(attr.getValue())
921 .getAsValueRange<IteratorTypeAttr, IteratorType>();
922 // Convert IteratorType enums into the string representation. This is
923 // needed, because tests still use the old format when 'iterator_types'
924 // attribute is represented as an array of strings.
925 // TODO: Remove this conversion once tests are fixed.
926 SmallVector<Attribute> iteratorTypeNames =
927 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
928 return StringAttr::get(getContext(), stringifyIteratorType(t));
929 });
930
931 attrs.emplace_back(getIteratorTypesAttrName(),
932 ArrayAttr::get(getContext(), iteratorTypeNames));
933 } else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
934 // Omit fastmath when it equals the default (none) to keep output clean.
// Because fastmath is also a trait name, printOptionalAttrDict below elides
// it too, so a `none` fastmath attribute does not appear in the output.
935 if (attr.getName() == getFastmathAttrName() &&
936 llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
937 arith::FastMathFlags::none)
938 continue;
939 attrs.push_back(attr);
940 }
941 }
942
// Trait dictionary first, then the three operands.
943 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
944 p << " " << dictAttr << " " << getLhs() << ", ";
945 p << getRhs() << ", " << getAcc();
946
// Remaining (non-trait) attributes; trait names were already printed above.
947 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
948 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
949 << getResultType();
950}
951
952static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
953 const std::vector<std::pair<int64_t, int64_t>> &map) {
954 for (auto &dimPair : map) {
955 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
956 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
957 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
958 return false;
959 }
960 return true;
961}
962
963static LogicalResult verifyOutputShape(
964 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
965 Type resType,
966 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
967 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
968 DenseSet<int64_t> lhsContractingDimSet;
969 DenseSet<int64_t> rhsContractingDimSet;
970 for (auto &dimPair : contractingDimMap) {
971 lhsContractingDimSet.insert(dimPair.first);
972 rhsContractingDimSet.insert(dimPair.second);
973 }
974 DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
975 llvm::make_second_range(batchDimMap));
976
977 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
978 SmallVector<int64_t, 4> expectedResultDims;
979 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
980 if (lhsContractingDimSet.count(i) > 0)
981 continue;
982 expectedResultDims.push_back(lhsType.getDimSize(i));
983 }
984
985 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
986 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
987 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
988 continue;
989 expectedResultDims.push_back(rhsType.getDimSize(i));
990 }
991
992 // Verify 'expectedResultDims'.
993 if (expectedResultDims.empty()) {
994 // No batch or free dimension implies a scalar result.
995 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
996 return op.emitOpError("invalid accumulator/result vector shape");
997 } else {
998 // At least one batch or free dimension implies a vector result.
999 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
1000 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
1001 if (!resVectorType || !accVectorType)
1002 return op.emitOpError("invalid accumulator/result vector shape");
1003
1004 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
1005 // types fully define the result vector type. This assumes the affine maps
1006 // are well-formed, which must have been verified already.
1007 MLIRContext *ctx = op.getContext();
1008 AffineMap lhsMap = op.getIndexingMapsArray()[0];
1009 AffineMap rhsMap = op.getIndexingMapsArray()[1];
1010 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
1011 return op.emitOpError(
1012 "expected all dimensions to be either a LHS or a RHS dimension");
1013 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
1014 for (auto pair :
1015 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1016 VectorType v = pair.first;
1017 auto map = pair.second;
1018 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1019 unsigned pos = map.getDimPosition(idx);
1020 if (!extents[pos])
1021 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
1022 }
1023 }
1024 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
1025 return op.emitOpError("expected all dimensions to get an extent as "
1026 "either a LHS or a RHS dimension");
1027
1028 AffineMap resMap = op.getIndexingMapsArray()[2];
1029 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
1030 /*symbolCount=*/0, extents, ctx);
1031 // Compose the resMap with the extentsMap, which is a constant map.
1032 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
1033 assert(llvm::all_of(expectedMap.getResults(),
1034 llvm::IsaPred<AffineConstantExpr>) &&
1035 "expected constant extent along all dimensions.");
1036 // Extract the expected shape and build the type.
1037 auto expectedShape =
1038 llvm::map_to_vector<4>(expectedMap.getResults(), [](AffineExpr e) {
1039 return cast<AffineConstantExpr>(e).getValue();
1040 });
1041 auto expected =
1042 VectorType::get(expectedShape, resVectorType.getElementType(),
1043 resVectorType.getScalableDims());
1044 if (resVectorType != expected || accVectorType != expected)
1045 return op.emitOpError(
1046 "invalid accumulator/result vector shape, expected: ")
1047 << expected;
1048 }
1049 return success();
1050}
1051
1052LogicalResult ContractionOp::verify() {
1053 VectorType lhsType = getLhsType();
1054 VectorType rhsType = getRhsType();
1055 Type accType = getAccType();
1056 Type resType = getResultType();
1057
1058 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1059 if (!lhsType.getElementType().isSignlessInteger())
1060 return emitOpError("only supports signless integer types");
1061 }
1062
1063 // Verify that an indexing map was specified for each vector operand.
1064 if (getIndexingMapsArray().size() != 3)
1065 return emitOpError("expected an indexing map for each vector operand");
1066
1067 // Verify that each index map has 'numIterators' inputs, no symbols, and
1068 // that the number of map outputs equals the rank of its associated
1069 // vector operand.
1070 unsigned numIterators = getIteratorTypes().getValue().size();
1071 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1072 auto index = it.index();
1073 auto map = it.value();
1074 if (map.getNumSymbols() != 0)
1075 return emitOpError("expected indexing map ")
1076 << index << " to have no symbols";
1077 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1078 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1079 // Verify that the map has the right number of inputs, outputs, and indices.
1080 // This also correctly accounts for (..) -> () for rank-0 results.
1081 if (map.getNumDims() != numIterators)
1082 return emitOpError("expected indexing map ")
1083 << index << " to have " << numIterators << " number of inputs";
1084 if (map.getNumResults() != rank)
1085 return emitOpError("expected indexing map ")
1086 << index << " to have " << rank << " number of outputs";
1087 if (!map.isProjectedPermutation())
1088 return emitOpError("expected indexing map ")
1089 << index << " to be a projected permutation of its inputs";
1090 }
1091
1092 auto contractingDimMap = getContractingDimMap();
1093 auto batchDimMap = getBatchDimMap();
1094
1095 // Verify at least one contracting dimension pair was specified.
1096 if (contractingDimMap.empty())
1097 return emitOpError("expected at least one contracting dimension pair");
1098
1099 // Verify contracting dimension map was properly constructed.
1100 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1101 return emitOpError("invalid contracting dimension map");
1102
1103 // Verify batch dimension map was properly constructed.
1104 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1105 return emitOpError("invalid batch dimension map");
1106
1107 // Verify 'accType' and 'resType' shape.
1108 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1109 contractingDimMap, batchDimMap)))
1110 return failure();
1111
1112 if (!getKindAttr()) {
1113 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
1114 "'vector.kind<add>')");
1115 }
1116
1117 // Verify supported combining kind.
1118 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1119 auto elementType = vectorType ? vectorType.getElementType() : resType;
1120 if (!isSupportedCombiningKind(getKind(), elementType))
1121 return emitOpError("unsupported contraction type");
1122
1123 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1124 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1125}
1126
1127// MaskableOpInterface methods.
1128
1129/// Returns the mask type expected by this operation. Mostly used for
1130/// verification purposes. It requires the operation to be vectorized."
1131Type ContractionOp::getExpectedMaskType() {
1132 auto indexingMaps = this->getIndexingMapsArray();
1133 AffineMap lhsIdxMap = indexingMaps[0];
1134 AffineMap rhsIdxMap = indexingMaps[1];
1135 VectorType lhsType = this->getLhsType();
1136 VectorType rhsType = this->getRhsType();
1137
1138 unsigned numVecDims = lhsIdxMap.getNumDims();
1139 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1140 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1141
1142 // Using the information in the indexing maps, extract the size of each
1143 // dimension in the vector.contract operation from the two input operands.
1144 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1145 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1146 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1147 lhsType.getScalableDims()[dimIdx];
1148 }
1149 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1150 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1151 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1152 rhsType.getScalableDims()[dimIdx];
1153 }
1154
1155 assert(ShapedType::isStaticShape(maskShape) &&
1156 "Mask shape couldn't be computed");
1157
1158 return VectorType::get(maskShape,
1159 IntegerType::get(lhsType.getContext(), /*width=*/1),
1160 maskShapeScalableDims);
1161}
1162
1163SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1164 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1165 getIteratorTypesAttrName(), getKindAttrName(),
1166 getFastmathAttrName()};
1167}
1168
// Linear scan over the results of `map`: returns the index of the first
// result equal to `targetExpr`, or -1 if none matches. Callers treat a
// negative value as "this dimension is not indexed by the map".
1170 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1171 if (targetExpr == map.getResult(i))
1172 return i;
1173 return -1;
1174}
1175
1176static std::vector<std::pair<int64_t, int64_t>>
1177getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1178 IteratorType targetIteratorType, MLIRContext *context) {
1179 std::vector<std::pair<int64_t, int64_t>> dimMap;
1180 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1181 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1182 if (iteratorType != targetIteratorType)
1183 continue;
1184 // Search lhs/rhs map results for 'targetExpr'.
1185 auto targetExpr = getAffineDimExpr(it.index(), context);
1186 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1187 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1188 if (lhsDim >= 0 && rhsDim >= 0)
1189 dimMap.emplace_back(lhsDim, rhsDim);
1190 }
1191 return dimMap;
1192}
1193
1194void ContractionOp::getIterationBounds(
1195 SmallVectorImpl<int64_t> &iterationBounds) {
1196 auto lhsShape = getLhsType().getShape();
1197 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1198 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1199 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1200 // Search lhs/rhs map results for 'targetExpr'.
1201 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1202 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1203 if (iteratorType == IteratorType::reduction) {
1204 // Get reduction dim size from lhs shape (same size in rhsShape).
1205 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1206 assert(lhsDimIndex >= 0);
1207 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1208 continue;
1209 }
1210 // Get parallel dimension size from result shape.
1211 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1212 assert(resDimIndex >= 0);
1213 assert(resVectorType != nullptr);
1214 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1215 }
1216}
1217
1218void ContractionOp::getIterationIndexMap(
1219 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1220 unsigned numMaps = getIndexingMapsArray().size();
1221 iterationIndexMap.resize(numMaps);
1222 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1223 auto index = it.index();
1224 auto map = it.value();
1225 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1226 auto dim = cast<AffineDimExpr>(map.getResult(i));
1227 iterationIndexMap[index][dim.getPosition()] = i;
1228 }
1229 }
1230}
1231
1232std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1233 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1234 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1235 getContext());
1236}
1237
1238std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1239 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1240 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1241 getContext());
1242}
1243
/// Unroll shape of the contraction: the bounds appended by getIterationBounds.
/// The only return statement returns a value, so this never yields
/// std::nullopt.
1244std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1246 getIterationBounds(shape);
1247 return shape;
1248}
1249
1250/// Return a fused vector::ContractionOp which represents a patterns such as:
1251///
1252/// ```mlir
1253/// %c0 = vector.constant 0: ...
1254/// %c = vector.contract %a, %b, %c0: ...
1255/// %e = add %c, %d: ...
1256/// ```
1257///
1258/// by:
1259///
1260/// ```mlir
1261/// %e = vector.contract %a, %b, %d: ...
1262/// ```
1263///
1264/// Return null if the canonicalization does not apply.
1265// TODO: This should be a folding of Add into Contract in core but while they
1266// live in different dialects, it is not possible without unnatural
1267// dependencies.
1268template <typename AddOpType>
1269struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1270 using OpRewritePattern<AddOpType>::OpRewritePattern;
1271
1272 LogicalResult matchAndRewrite(AddOpType addOp,
1273 PatternRewriter &rewriter) const override {
1274 auto canonicalize = [&](Value maybeContraction,
1275 Value otherOperand) -> vector::ContractionOp {
1276 vector::ContractionOp contractionOp =
1277 dyn_cast_or_null<vector::ContractionOp>(
1278 maybeContraction.getDefiningOp());
1279 if (!contractionOp)
1280 return vector::ContractionOp();
1281 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1282 contractionOp.getAcc().getDefiningOp())) {
1283 if (maybeZero.getValue() ==
1284 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1285 IRMapping bvm;
1286 bvm.map(contractionOp.getAcc(), otherOperand);
1287 auto newContraction =
1288 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1289 rewriter.replaceOp(addOp, newContraction.getResult());
1290 return newContraction;
1291 }
1292 }
1293 return vector::ContractionOp();
1294 };
1295
1296 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1297 vector::ContractionOp contract = canonicalize(a, b);
1298 contract = contract ? contract : canonicalize(b, a);
1299 return contract ? success() : failure();
1300 }
1301};
1302
/// Adds the canonicalization patterns of vector.contract to `results`.
// NOTE(review): the pattern list is not visible here; presumably it
// instantiates CanonicalizeContractAdd for the integer and float add ops —
// confirm.
1303void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1304 MLIRContext *context) {
1307}
1308
1309// Returns `true` if `index` is either within [0, maxIndex) or equal to
1310// `poisonValue`.
1312 int64_t maxIndex) {
// `poisonValue` is tested first, so it is accepted even though it may lie
// outside [0, maxIndex) (e.g. a negative poison marker such as -1).
1313 return index == poisonValue || (index >= 0 && index < maxIndex);
1314}
1315
1316//===----------------------------------------------------------------------===//
1317// ExtractOp
1318//===----------------------------------------------------------------------===//
1319
1320void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1321 SetIntRangeFn setResultRanges) {
1322 setResultRanges(getResult(), argRanges.front());
1323}
1324
1325void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1326 Value source) {
1327 auto vectorTy = cast<VectorType>(source.getType());
1328 build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1329}
1330
1331void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1332 Value source, int64_t position) {
1333 build(builder, result, source, ArrayRef<int64_t>{position});
1334}
1335
1336void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1337 Value source, OpFoldResult position) {
1338 build(builder, result, source, ArrayRef<OpFoldResult>{position});
1339}
1340
1341void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1342 Value source, ArrayRef<int64_t> position) {
1343 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1344 builder.getDenseI64ArrayAttr(position));
1345}
1346
1347void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1348 Value source, ArrayRef<OpFoldResult> position) {
1349 SmallVector<int64_t> staticPos;
1350 SmallVector<Value> dynamicPos;
1351 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1352 build(builder, result, source, dynamicPos,
1353 builder.getDenseI64ArrayAttr(staticPos));
1354}
1355
1356LogicalResult
1357ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1358 ExtractOp::Adaptor adaptor,
1359 SmallVectorImpl<Type> &inferredReturnTypes) {
1360 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1361 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1362 vectorType.getRank()) {
1363 inferredReturnTypes.push_back(vectorType.getElementType());
1364 } else {
1365 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1366 vectorType.getRank());
1367 inferredReturnTypes.push_back(VectorType::get(
1368 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1369 vectorType.getScalableDims().drop_front(n)));
1370 }
1371 return success();
1372}
1373
/// Verifies vector.extract:
/// - the result must not be a 0-d vector;
/// - the number of kDynamic markers in the static position must equal the
///   number of dynamic position operands;
/// - the position may not be longer than the source rank;
/// - every constant position entry must be within its dimension or be the
///   poison index.
1374LogicalResult vector::ExtractOp::verify() {
// NOTE(review): unlike the checks below this uses emitError rather than
// emitOpError — confirm the different diagnostic is intended.
1375 if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1376 if (resTy.getRank() == 0)
1377 return emitError(
1378 "expected a scalar instead of a 0-d vector as the result type");
1379
1380 // Note: This check must come before getMixedPosition() to prevent a crash.
1381 auto dynamicMarkersCount =
1382 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1383 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1384 return emitOpError(
1385 "mismatch between dynamic and static positions (kDynamic marker but no "
1386 "corresponding dynamic position) -- this can only happen due to an "
1387 "incorrect fold/rewrite");
1388 auto position = getMixedPosition();
1389 if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1390 return emitOpError(
1391 "expected position attribute of rank no greater than vector rank");
// Only constant (attribute) entries are range-checked; dynamic (Value)
// entries are not verified here.
1392 for (auto [idx, pos] : llvm::enumerate(position)) {
1393 if (auto attr = dyn_cast<Attribute>(pos)) {
1394 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1396 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1397 return emitOpError("expected position attribute #")
1398 << (idx + 1)
1399 << " to be a non-negative integer smaller than the "
1400 "corresponding vector dimension or poison (-1)";
1401 }
1402 }
1403 }
1404 return success();
1405}
1406
/// Converts the elements of `arrayAttr`, viewed as IntegerAttrs, into a
/// vector of `IntType` by static_cast-ing each element's 64-bit value.
// NOTE(review): static_cast silently truncates/wraps when IntType is narrower
// than the stored value.
1407template <typename IntType>
1409 return llvm::map_to_vector<4>(
1410 arrayAttr.getAsRange<IntegerAttr>(),
1411 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1412}
1413
1414/// Fold the result of chains of ExtractOp in place by simply concatenating the
1415/// positions.
1416static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1417 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1418 return failure();
1419
1420 // TODO: Canonicalization for dynamic position not implemented yet.
1421 if (extractOp.hasDynamicPosition())
1422 return failure();
1423
1424 SmallVector<int64_t> globalPosition;
1425 ExtractOp currentOp = extractOp;
1426 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1427 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1428 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1429 currentOp = nextOp;
1430 // TODO: Canonicalization for dynamic position not implemented yet.
1431 if (currentOp.hasDynamicPosition())
1432 return failure();
1433 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1434 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1435 }
1436 extractOp.setOperand(0, currentOp.getSource());
1437 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1438 OpBuilder b(extractOp.getContext());
1439 std::reverse(globalPosition.begin(), globalPosition.end());
1440 extractOp.setStaticPosition(globalPosition);
1441 return success();
1442}
1443
1444namespace {
1445/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1446/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1447/// Compose TransposeOp permutations as we walk back.
1448/// This helper class keeps an updated extraction position `extractPosition`
1449/// with extra trailing sentinels.
1450/// The sentinels encode the internal transposition status of the result vector.
1451/// As we iterate, extractPosition is permuted and updated.
1452class ExtractFromInsertTransposeChainState {
1453public:
1454 ExtractFromInsertTransposeChainState(ExtractOp e);
1455
1456 /// Iterate over producing insert and transpose ops until we find a fold.
1457 Value fold();
1458
1459private:
1460 /// Return true if the vector at position `a` is contained within the vector
1461 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1462 /// is a prefix of `b`.
1463 template <typename ContainerA, typename ContainerB>
1464 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1465 return a.size() <= b.size() &&
1466 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1467 }
1468
1469 /// Return true if the vector at position `a` intersects the vector at
1470 /// position `b`. Under insert/extract semantics, this is the same as equality
1471 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1472 /// Comparison is on the common prefix (i.e. zip).
1473 template <typename ContainerA, typename ContainerB>
1474 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1475 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1476 if (elemA < 0 || elemB < 0)
1477 continue;
1478 if (elemA != elemB)
1479 return false;
1480 }
1481 return true;
1482 }
1483
1484 /// Folding is only possible in the absence of an internal permutation in the
1485 /// result vector.
1486 bool canFold() {
1487 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1488 }
1489
1490 // Helper to get the next defining op of interest.
1491 void updateStateForNextIteration(Value v) {
1492 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1493 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1494 };
1495
1496 // Case 1. If we hit a transpose, just compose the map and iterate.
1497 // Invariant: insert + transpose do not change rank, we can always compose.
1498 LogicalResult handleTransposeOp();
1499
1500 // Case 2: the insert position matches extractPosition exactly, early return.
1501 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1502
1503 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1504 /// portion of the source of the insert.
1505 /// Example:
1506 /// ```
1507 /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
1508 /// // extractPosition == [1, 2, 3]
1509 /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1510 /// // can fold to vector.extract %source[0, 3]
1511 /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1512 /// ```
1513 /// To traverse through %source, we need to set the leading dims to 0 and
1514 /// drop the extra leading dims.
1515 /// This method updates the internal state.
1516 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1517
1518 /// Try to fold in place to extract(source, extractPosition) and return the
1519 /// folded result. Return null if folding is not possible (e.g. due to an
1520 /// internal transposition in the result).
1521 Value tryToFoldExtractOpInPlace(Value source);
1522
1523 ExtractOp extractOp;
1524 int64_t vectorRank;
1525 int64_t extractedRank;
1526
1527 InsertOp nextInsertOp;
1528 TransposeOp nextTransposeOp;
1529
1530 /// Sentinel values that encode the internal permutation status of the result.
1531 /// They are set to (-1, ... , -k) at the beginning and appended to
1532 /// `extractPosition`.
1533 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1534 /// ensure that there is no internal transposition.
1535 /// Internal transposition cannot be accounted for with a folding pattern.
1536 // TODO: We could relax the internal transposition with an extra transposition
1537 // operation in a future canonicalizer.
1538 SmallVector<int64_t> sentinels;
1539 SmallVector<int64_t> extractPosition;
1540};
1541} // namespace
1542
1543ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1544 ExtractOp e)
1545 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1546 extractedRank(extractOp.getNumIndices()) {
1547 assert(vectorRank >= extractedRank && "Extracted position overflow");
1548 sentinels.reserve(vectorRank - extractedRank);
1549 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1550 sentinels.push_back(-(i + 1));
1551 extractPosition.assign(extractOp.getStaticPosition().begin(),
1552 extractOp.getStaticPosition().end());
1553 llvm::append_range(extractPosition, sentinels);
1554}
1555
1556// Case 1. If we hit a transpose, just compose the map and iterate.
1557// Invariant: insert + transpose do not change rank, we can always compose.
// Fails without touching any state when the extract has a dynamic position or
// the current producer is not a TransposeOp. Otherwise it composes the
// permutation of `nextTransposeOp` into the tracked position.
// NOTE(review): the composition statements are only partially visible here —
// confirm how `extractPosition` is permuted.
1558LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1559 // TODO: Canonicalization for dynamic position not implemented yet.
1560 if (extractOp.hasDynamicPosition())
1561 return failure();
1562
1563 if (!nextTransposeOp)
1564 return failure();
1566 nextTransposeOp.getPermutation(), extractOp.getContext()));
1568 return success();
1569}
1570
1571// Case 2: the insert position matches extractPosition exactly, early return.
1572LogicalResult
1573ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1574 Value &res) {
1575 // TODO: Canonicalization for dynamic position not implemented yet.
1576 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1577 return failure();
1578
1579 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1580 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1581 return failure();
1582 // Case 2.a. early-exit fold.
1583 res = nextInsertOp.getValueToStore();
1584 // Case 2.b. if internal transposition is present, canFold will be false.
1585 return success(canFold());
1586}
1587
1588/// Case 3: if inserted position is a prefix of extractPosition,
1589/// extract a portion of the source of the insertion.
1590/// This method updates the internal state.
1591LogicalResult
1592ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1593 // TODO: Canonicalization for dynamic position not implemented yet.
1594 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1595 return failure();
1596
1597 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1598 if (!isContainedWithin(insertedPos, extractPosition))
1599 return failure();
1600 // Set leading dims to zero.
1601 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1602 // Drop extra leading dims.
1603 extractPosition.erase(extractPosition.begin(),
1604 extractPosition.begin() + insertedPos.size());
1605 extractedRank = extractPosition.size() - sentinels.size();
1606 // Case 3.a. early-exit fold (break and delegate to post-while path).
1607 res = nextInsertOp.getValueToStore();
1608 // Case 3.b. if internal transposition is present, canFold will be false.
1609 return success();
1610}
1611
1612/// Try to fold in place to extract(source, extractPosition) and return the
1613/// folded result. Return null if folding is not possible (e.g. due to an
1614/// internal transposition in the result).
1615Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1616 Value source) {
1617 // TODO: Canonicalization for dynamic position not implemented yet.
1618 if (extractOp.hasDynamicPosition())
1619 return Value();
1620
1621 // If we can't fold (either internal transposition, or nothing to fold), bail.
1622 bool nothingToFold = (source == extractOp.getSource());
1623 if (nothingToFold || !canFold())
1624 return Value();
1625
1626 // Otherwise, fold by updating the op inplace and return its result.
1627 OpBuilder b(extractOp.getContext());
1628 extractOp.setStaticPosition(
1629 ArrayRef(extractPosition).take_front(extractedRank));
1630 extractOp.getSourceMutable().assign(source);
1631 return extractOp.getResult();
1632}
1633
1634/// Iterate over producing insert and transpose ops until we find a fold.
1635Value ExtractFromInsertTransposeChainState::fold() {
1636 // TODO: Canonicalization for dynamic position not implemented yet.
1637 if (extractOp.hasDynamicPosition())
1638 return Value();
1639
1640 Value valueToExtractFrom = extractOp.getSource();
1641 updateStateForNextIteration(valueToExtractFrom);
1642 while (nextInsertOp || nextTransposeOp) {
1643 // Case 1. If we hit a transpose, just compose the map and iterate.
1644 // Invariant: insert + transpose do not change rank, we can always compose.
1645 if (succeeded(handleTransposeOp())) {
1646 valueToExtractFrom = nextTransposeOp.getVector();
1647 updateStateForNextIteration(valueToExtractFrom);
1648 continue;
1649 }
1650
1651 Value result;
1652 // Case 2: the position match exactly.
1653 if (succeeded(handleInsertOpWithMatchingPos(result)))
1654 return result;
1655
1656 // Case 3: if the inserted position is a prefix of extractPosition, we can
1657 // just extract a portion of the source of the insert.
1658 if (succeeded(handleInsertOpWithPrefixPos(result)))
1659 return tryToFoldExtractOpInPlace(result);
1660
1661 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1662 // values. This is a more difficult case and we bail.
1663 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1664 if (isContainedWithin(extractPosition, insertedPos) ||
1665 intersectsWhereNonNegative(extractPosition, insertedPos))
1666 return Value();
1667
1668 // Case 5: No intersection, we forward the extract to insertOp.dest().
1669 valueToExtractFrom = nextInsertOp.getDest();
1670 updateStateForNextIteration(valueToExtractFrom);
1671 }
1672 // If after all this we can fold, go for it.
1673 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1674}
1675
1676/// Returns true if the operation has a 0-D vector type operand or result.
// Only a VectorType of rank 0 matches; scalar types and higher-rank vectors
// do not.
1678 auto hasZeroDimVectorType = [](Type type) -> bool {
1679 auto vecType = dyn_cast<VectorType>(type);
1680 return vecType && vecType.getRank() == 0;
1681 };
1682
// Operand types are scanned first; the result types are only scanned if no
// operand matches.
1683 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1684 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1685}
1686
1687/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1688/// considered to be 'broadcastlike'.
1689static bool isBroadcastLike(Operation *op) {
1690 if (isa<BroadcastOp>(op))
1691 return true;
1692
1693 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1694 if (!shapeCast)
1695 return false;
1696
1697 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1698 // Checking that the destination shape has a prefix of 1s is not sufficient,
1699 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1700 // is that the source shape is a suffix of the destination shape.
1701 VectorType srcType = shapeCast.getSourceVectorType();
1702 ArrayRef<int64_t> srcShape = srcType.getShape();
1703 uint64_t srcRank = srcType.getRank();
1704 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1705 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1706}
1707
1708/// Fold extract(broadcast(X)) to either extract(X) or just X.
1709///
1710/// Example:
1711///
1712/// broadcast extract [1][2]
1713/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1714///
1715/// becomes
1716/// extract [1]
1717/// (3,4) -------------------------------------> (4)
1718///
1719///
1720/// The variable names used in this implementation correspond to the above
1721/// shapes as,
1722///
1723/// - (3, 4) is `input` shape.
1724/// - (2, 3, 4) is `broadcast` shape.
1725/// - (4) is `extract` shape.
1726///
1727/// This folding is possible when the suffix of `input` shape is the same as
1728/// `extract` shape.
1729static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1730
1731 Operation *defOp = extractOp.getSource().getDefiningOp();
1732 if (!defOp || !isBroadcastLike(defOp))
1733 return Value();
1734
1735 Value input = defOp->getOperand(0);
1736
1737 // Replace extract(broadcast(X)) with X
1738 if (extractOp.getType() == input.getType())
1739 return input;
1740
1741 // Get required types and ranks in the chain
1742 // input -> broadcast -> extract
1743 // (scalars are treated as rank-0).
1744 auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1745 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1746 unsigned inputRank = inputType ? inputType.getRank() : 0;
1747 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1748 unsigned extractRank = extractType ? extractType.getRank() : 0;
1749
1750 // Cannot do without the broadcast if overall the rank increases.
1751 if (extractRank > inputRank)
1752 return Value();
1753
1754 // The above condition guarantees that input is a vector.
1755 assert(inputType && "input must be a vector type because of previous checks");
1756 ArrayRef<int64_t> inputShape = inputType.getShape();
1757
1758 // In the case where there is a broadcast dimension in the suffix, it is not
1759 // possible to replace extract(broadcast(X)) with extract(X). Example:
1760 //
1761 // broadcast extract
1762 // (1) --------> (3,4) ------> (4)
1763 if (extractType &&
1764 extractType.getShape() != inputShape.take_back(extractRank))
1765 return Value();
1766
1767 // Replace extract(broadcast(X)) with extract(X).
1768 // First, determine the new extraction position.
1769 unsigned deltaOverall = inputRank - extractRank;
1770 unsigned deltaBroadcast = broadcastRank - inputRank;
1771 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1772 SmallVector<OpFoldResult> newPositions(deltaOverall);
1773 IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1774 for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1775 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1776 }
1777 auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1778 extractOp->setOperands(
1779 llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1780 extractOp.setStaticPosition(staticPos);
1781 return extractOp.getResult();
1782}
1783
1784/// Fold extractOp coming from ShuffleOp.
1785///
1786/// Example:
1787///
1788/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1789/// : vector<8xf32>, vector<8xf32>
1790/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1791/// ->
1792/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1793///
1794static Value foldExtractFromShuffle(ExtractOp extractOp) {
1795 // Dynamic positions are not folded as the resulting code would be more
1796 // complex than the input code.
1797 if (extractOp.hasDynamicPosition())
1798 return Value();
1799
1800 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1801 if (!shuffleOp)
1802 return Value();
1803
1804 // TODO: 0-D or multi-dimensional vectors not supported yet.
1805 if (shuffleOp.getResultVectorType().getRank() != 1)
1806 return Value();
1807
1808 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1809 auto shuffleMask = shuffleOp.getMask();
1810 int64_t extractIdx = extractOp.getStaticPosition()[0];
1811 int64_t shuffleIdx = shuffleMask[extractIdx];
1812
1813 // Find the shuffled vector to extract from based on the shuffle index.
1814 if (shuffleIdx < inputVecSize) {
1815 extractOp.setOperand(0, shuffleOp.getV1());
1816 extractOp.setStaticPosition({shuffleIdx});
1817 } else {
1818 extractOp.setOperand(0, shuffleOp.getV2());
1819 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1820 }
1821
1822 return extractOp.getResult();
1823}
1824
// Fold extractOp with source coming from ShapeCast op.
//
// When the trailing `destinationRank` dimensions of the shape_cast source
// match those of the extract result, extract(shape_cast(X)) is rewritten in
// place into extract(X). The extract position is linearized using the
// extract-source dimensions above the extracted chunk, then delinearized using
// the corresponding dimensions of X. Returns the updated extract result, or a
// null Value when the fold does not apply.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  // TODO: Canonicalization for dynamic position not implemented yet.
  if (extractOp.hasDynamicPosition())
    return Value();

  auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();

  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  // Rank of the extracted value; scalars count as rank 0.
  int64_t destinationRank =
      llvm::isa<VectorType>(extractOp.getType())
          ? llvm::cast<VectorType>(extractOp.getType()).getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType =
        llvm::cast<VectorType>(extractOp.getResult().getType());
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  // Positions are reversed so index 0 is the innermost indexed dimension,
  // matching the innermost-first order in which strides are pushed below.
  SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    // Skip the `destinationRank` innermost dims: they form the extracted
    // chunk, so the position counts whole chunks.
    stride *=
        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  // Back to outermost-first order for delinearize.
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
  // NOTE(review): `b` is not used below; the position is set directly.
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setStaticPosition(newPosition);
  extractOp.setOperand(0, shapeCastOp.getSource());
  return extractOp.getResult();
}
1889
1890/// Fold an ExtractOp from ExtractStridedSliceOp.
1891static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1892 // TODO: Canonicalization for dynamic position not implemented yet.
1893 if (extractOp.hasDynamicPosition())
1894 return Value();
1895
1896 auto extractStridedSliceOp =
1897 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1898 if (!extractStridedSliceOp)
1899 return Value();
1900
1901 // 0-D vectors not supported.
1902 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1903 if (hasZeroDimVectors(extractStridedSliceOp))
1904 return Value();
1905
1906 // Return if 'extractStridedSliceOp' has non-unit strides.
1907 if (extractStridedSliceOp.hasNonUnitStrides())
1908 return Value();
1909
1910 // Trim offsets for dimensions fully extracted.
1911 auto sliceOffsets =
1912 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1913 while (!sliceOffsets.empty()) {
1914 size_t lastOffset = sliceOffsets.size() - 1;
1915 if (sliceOffsets.back() != 0 ||
1916 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1917 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1918 break;
1919 sliceOffsets.pop_back();
1920 }
1921 unsigned destinationRank = 0;
1922 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1923 destinationRank = vecType.getRank();
1924 // The dimensions of the result need to be untouched by the
1925 // extractStridedSlice op.
1926 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1927 sliceOffsets.size())
1928 return Value();
1929
1930 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1931 assert(extractedPos.size() >= sliceOffsets.size());
1932 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1933 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1934 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1935
1936 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1937 OpBuilder b(extractOp.getContext());
1938 extractOp.setStaticPosition(extractedPos);
1939 return extractOp.getResult();
1940}
1941
1942/// Fold extract_op fed from a chain of insertStridedSlice ops.
1943static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1944 // TODO: Canonicalization for dynamic position not implemented yet.
1945 if (extractOp.hasDynamicPosition())
1946 return Value();
1947
1948 int64_t destinationRank =
1949 llvm::isa<VectorType>(extractOp.getType())
1950 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1951 : 0;
1952 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1953 if (!insertOp)
1954 return Value();
1955
1956 // 0-D vectors not supported.
1957 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1958 if (hasZeroDimVectors(insertOp))
1959 return Value();
1960
1961 while (insertOp) {
1962 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1963 insertOp.getSourceVectorType().getRank();
1964 if (destinationRank > insertOp.getSourceVectorType().getRank())
1965 return Value();
1966 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1967 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1968
1969 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1970 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1971 }))
1972 return Value();
1973 bool disjoint = false;
1974 SmallVector<int64_t, 4> offsetDiffs;
1975 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1976 int64_t start = insertOffsets[dim];
1977 int64_t size =
1978 (dim < insertRankDiff)
1979 ? 1
1980 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1981 int64_t end = start + size;
1982 int64_t offset = extractOffsets[dim];
1983 // Check if the start of the extract offset is in the interval inserted.
1984 if (start <= offset && offset < end) {
1985 if (dim >= insertRankDiff)
1986 offsetDiffs.push_back(offset - start);
1987 continue;
1988 }
1989 disjoint = true;
1990 break;
1991 }
1992 // The extract element chunk overlap with the vector inserted.
1993 if (!disjoint) {
1994 // If any of the inner dimensions are only partially inserted we have a
1995 // partial overlap.
1996 int64_t srcRankDiff =
1997 insertOp.getSourceVectorType().getRank() - destinationRank;
1998 for (int64_t i = 0; i < destinationRank; i++) {
1999 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
2000 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
2001 insertRankDiff))
2002 return Value();
2003 }
2004 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2005 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2006 OpBuilder b(extractOp.getContext());
2007 extractOp.setStaticPosition(offsetDiffs);
2008 return extractOp.getResult();
2009 }
2010 // If the chunk extracted is disjoint from the chunk inserted, keep
2011 // looking in the insert chain.
2012 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2013 }
2014 return Value();
2015}
2016
2017/// Try to fold the extraction of a scalar from a vector defined by
2018/// vector.from_elements. E.g.:
2019///
2020/// %0 = vector.from_elements %a, %b : vector<2xf32>
2021/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2022/// ==> fold to %a
2023static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2024 // Dynamic extractions cannot be folded.
2025 if (extractOp.hasDynamicPosition())
2026 return {};
2027
2028 // Look for extract(from_elements).
2029 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2030 if (!fromElementsOp)
2031 return {};
2032
2033 // Scalable vectors are not supported.
2034 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2035 if (vecType.isScalable())
2036 return {};
2037
2038 // Only extractions of scalars are supported.
2039 int64_t rank = vecType.getRank();
2040 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2041 if (extractOp.getType() != vecType.getElementType())
2042 return {};
2043 assert(static_cast<int64_t>(indices.size()) == rank &&
2044 "unexpected number of indices");
2045
2046 // Compute flattened/linearized index and fold to operand.
2047 int flatIndex = 0;
2048 int stride = 1;
2049 for (int i = rank - 1; i >= 0; --i) {
2050 flatIndex += indices[i] * stride;
2051 stride *= vecType.getDimSize(i);
2052 }
2053 return fromElementsOp.getElements()[flatIndex];
2054}
2055
/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
/// then fold them into the op's static position, updating `op` in place.
///
/// A dynamic index is folded when the adaptor provides an IntegerAttr for it
/// whose value lies in [-1, dim size); -1 is the poison index.
///
/// `operands` must hold the op's leading non-position operands on entry
/// (ExtractOp::fold passes `{getSource()}`). Every dynamic index that is not
/// folded is appended to it. Note that this append happens even when nothing
/// folds.
///
/// Returns `op`'s result if any index was folded, or a null Value otherwise.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
                                         SmallVectorImpl<Value> &operands) {
  std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
  OperandRange dynamicPosition = op.getDynamicPosition();
  // Constant values of the dynamic indices as seen by the folder; an entry
  // may be null when the index is not a known constant.
  ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
  // Shape of the vector being indexed: the source for extract, the dest for
  // insert.
  if constexpr (std::is_same_v<OpType, ExtractOp>)
    vectorShape = op.getSourceVectorType().getShape();
  else
    vectorShape = op.getDestVectorType().getShape();

  // If there are no dynamic operands, there is nothing to fold.
  if (!dynamicPosition.size())
    return {};

  // `index` is used to iterate over the `dynamicPosition`.
  unsigned index = 0;

  // `opChange` is a flag. If it is true, it means to update `op` in place.
  bool opChange = false;
  for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
    // Static entries have no corresponding dynamic operand.
    if (ShapedType::isStatic(staticPosition[i]))
      continue;
    Attribute positionAttr = dynamicPositionAttr[index];
    Value position = dynamicPosition[index++];
    if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
      int64_t value = attr.getInt();
      // Do not fold if the value is out of bounds (-1 signifies a poison
      // value rather than OOB index).
      if (value >= -1 && value < vectorShape[i]) {
        staticPosition[i] = attr.getInt();
        opChange = true;
        continue;
      }
    }
    // Not foldable: keep it as a dynamic operand.
    operands.push_back(position);
  }

  if (opChange) {
    op.setStaticPosition(staticPosition);
    op.getOperation()->setOperands(operands);
    // Return the original result to indicate an in-place folding happened.
    return op.getResult();
  }
  return {};
}
2105
/// Fold an insert or extract operation into a poison value when a poison index
/// is found at any dimension of the static position.
///
/// Returns a ub::PoisonAttr if `poisonVal` occurs in `staticPos`, or a null
/// Attribute otherwise.
                                               ArrayRef<int64_t> staticPos,
                                               int64_t poisonVal) {
  if (!is_contained(staticPos, poisonVal))
    return {};

  return ub::PoisonAttr::get(context);
}
2116
/// Fold a vector extract whose source is a poison value: the result is that
/// same poison attribute. Returns a null Attribute when the source attribute is
/// not poison (or is absent).
  if (matchPattern(srcAttr, ub::m_Poison()))
    return srcAttr;

  return {};
}
2124
/// Fold a vector extract extracting from a DenseElementsAttr.
///
/// Splat sources fold to the splat value, or to a splat of the result vector
/// type. Otherwise, for fixed-size sources with a fully static position, the
/// contiguous run of elements starting at the linearized position is copied
/// into a new attribute: a scalar for scalar results, or a DenseElementsAttr
/// of the result vector type. Returns a null Attribute when no fold applies.
                                                   Attribute srcAttr) {
  auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
  if (!denseAttr) {
    return {};
  }

  if (denseAttr.isSplat()) {
    Attribute newAttr = denseAttr.getSplatValue<Attribute>();
    if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
    return newAttr;
  }

  auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
  if (vecTy.isScalable())
    return {};

  if (extractOp.hasDynamicPosition()) {
    return {};
  }

  // Materializing subsets of a large constant array can generally lead to
  // explosion in IR size because of different combination of subsets that
  // can exist. However, vector.extract is a restricted form of subset
  // extract where you can only extract non-overlapping (or the same) subset for
  // a given rank of the subset. Because of this property, the IR size can only
  // increase at most by `rank * size(array)` from a single constant array being
  // extracted by multiple extracts.

  // Calculate the linearized position of the continuous chunk of elements to
  // extract. The static position is padded with zeros for the extracted
  // (trailing) dimensions.
  // NOTE(review): a negative (poison) static position would make `startPos`
  // negative here; this relies on the poison-index fold having run first.
  SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
  copy(extractOp.getStaticPosition(), completePositions.begin());
  int64_t startPos =
      linearize(completePositions, computeStrides(vecTy.getShape()));
  auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;

  TypedAttr newAttr;
  if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
    SmallVector<Attribute> elementValues(
        denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
    newAttr = DenseElementsAttr::get(resVecTy, elementValues);
  } else {
    newAttr = *denseValuesBegin;
  }

  return newAttr;
}
2175
2176OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2177 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2178 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2179 // mismatch).
2180 if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2181 return getSource();
2182 if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2183 return res;
2184 // Fold `arith.constant` indices into the `vector.extract` operation.
2185 // Do not stop here as this fold may enable subsequent folds that require
2186 // constant indices.
2187 SmallVector<Value> operands = {getSource()};
2188 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2189
2190 if (auto res = foldPoisonIndexInsertExtractOp(
2191 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2192 return res;
2193 if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2194 return res;
2195 if (succeeded(foldExtractOpFromExtractChain(*this)))
2196 return getResult();
2197 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2198 return res;
2199 if (auto res = foldExtractFromBroadcast(*this))
2200 return res;
2201 if (auto res = foldExtractFromShuffle(*this))
2202 return res;
2203 if (auto res = foldExtractFromShapeCast(*this))
2204 return res;
2205 if (auto val = foldExtractFromExtractStrided(*this))
2206 return val;
2207 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2208 return val;
2209 if (auto val = foldScalarExtractFromFromElements(*this))
2210 return val;
2211
2212 return inplaceFolded;
2213}
2214
2215namespace {
2216
2217// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2218class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2219public:
2220 using Base::Base;
2221
2222 LogicalResult matchAndRewrite(ExtractOp extractOp,
2223 PatternRewriter &rewriter) const override {
2224
2225 Operation *defOp = extractOp.getSource().getDefiningOp();
2226 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2227 if (!defOp || !isBroadcastLike(defOp) || !outType)
2228 return failure();
2229
2230 Value source = defOp->getOperand(0);
2231 if (isBroadcastableTo(source.getType(), outType) !=
2232 BroadcastableToResult::Success)
2233 return failure();
2234
2235 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2236 return success();
2237 }
2238};
2239
2240// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
2241class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2242public:
2243 using Base::Base;
2244
2245 LogicalResult matchAndRewrite(ExtractOp extractOp,
2246 PatternRewriter &rewriter) const override {
2247 auto createMaskOp =
2248 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2249 if (!createMaskOp)
2250 return failure();
2251
2252 VectorType extractedMaskType =
2253 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2254
2255 if (!extractedMaskType)
2256 return failure();
2257
2258 auto maskOperands = createMaskOp.getOperands();
2259 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2260 VectorType maskType = createMaskOp.getVectorType();
2261
2262 bool containsUnknownDims = false;
2263 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2264
2265 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2266 dimIdx++) {
2267 int64_t pos = extractOpPos[dimIdx];
2268 Value operand = maskOperands[dimIdx];
2269 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2270 if (!constantOp) {
2271 // Bounds of this dim unknown.
2272 containsUnknownDims = true;
2273 continue;
2274 }
2275
2276 int64_t createMaskBound =
2277 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2278
2279 if (pos != ShapedType::kDynamic) {
2280 // If any position is outside the range from the `create_mask`, then the
2281 // extracted mask will be all-false.
2282 allFalse |= pos >= createMaskBound;
2283 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2284 // This dim is not all-true and since this is a dynamic index we don't
2285 // know if the extraction is within the true or false region.
2286 // Note: Zero dims have already handled via getMaskFormat().
2287 containsUnknownDims = true;
2288 }
2289 }
2290
2291 if (allFalse) {
2292 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2293 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2294 } else if (!containsUnknownDims) {
2295 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2296 extractOp, extractedMaskType,
2297 maskOperands.drop_front(extractOpPos.size()));
2298 } else {
2299 return failure();
2300 }
2301 return success();
2302 }
2303};
2304
2305// Pattern to rewrite a ExtractOp(ConstantMask) -> ConstantMask.
2306class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
2307public:
2308 using Base::Base;
2309
2310 LogicalResult matchAndRewrite(ExtractOp extractOp,
2311 PatternRewriter &rewriter) const override {
2312 auto constantMaskOp =
2313 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2314 if (!constantMaskOp)
2315 return failure();
2316
2317 Type resultType = extractOp.getResult().getType();
2318 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2319
2320 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2321 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2322
2323 VectorType maskType = constantMaskOp.getVectorType();
2324
2325 // Check if any extracted position is outside the mask bounds.
2326 for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2327 int64_t pos = extractOpPos[dimIdx];
2328 if (pos == ShapedType::kDynamic) {
2329 // If the dim is all-true, a dynamic index is fine — any position
2330 // is within the masked region.
2331 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2332 continue;
2333 // Otherwise we don't know if the position is inside or outside of
2334 // the masked area, so bail out.
2335 return failure();
2336 }
2337
2338 // If the position is statically outside of the masked area, the result
2339 // will be all-false.
2340 if (pos >= maskDimSizes[dimIdx]) {
2341 if (extractedMaskType) {
2342 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2343 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2344 } else {
2345 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2346 extractOp, rewriter.getIntegerAttr(resultType, false));
2347 }
2348 return success();
2349 }
2350 }
2351
2352 // All positions are within the mask bounds.
2353 if (extractedMaskType) {
2354 // Vector result: the result is a constant_mask with the remaining
2355 // dimensions.
2356 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
2357 extractOp, extractedMaskType,
2358 maskDimSizes.drop_front(extractOpPos.size()));
2359 } else {
2360 // Scalar result: all positions are within the masked region, so the
2361 // result is true.
2362 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2363 extractOp, rewriter.getIntegerAttr(resultType, true));
2364 }
2365 return success();
2366 }
2367};
2368
2369// Folds extract(shape_cast(..)) into shape_cast when the total element count
2370// does not change.
2371LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2372 PatternRewriter &rewriter) {
2373 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2374 if (!castOp)
2375 return failure();
2376
2377 VectorType sourceType = castOp.getSourceVectorType();
2378 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2379 if (!targetType)
2380 return failure();
2381
2382 if (sourceType.getNumElements() != targetType.getNumElements())
2383 return failure();
2384
2385 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2386 castOp.getSource());
2387 return success();
2388}
2389
2390/// Try to canonicalize the extraction of a subvector from a vector defined by
2391/// vector.from_elements. E.g.:
2392///
2393/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2394/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2395/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2396LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2397 PatternRewriter &rewriter) {
2398 // Dynamic positions are not supported.
2399 if (extractOp.hasDynamicPosition())
2400 return failure();
2401
2402 // Scalar extracts are handled by the folder.
2403 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2404 if (!resultType)
2405 return failure();
2406
2407 // Look for extracts from a from_elements op.
2408 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2409 if (!fromElementsOp)
2410 return failure();
2411 VectorType inputType = fromElementsOp.getType();
2412
2413 // Scalable vectors are not supported.
2414 if (resultType.isScalable() || inputType.isScalable())
2415 return failure();
2416
2417 // Compute the position of first extracted element and flatten/linearize the
2418 // position.
2419 SmallVector<int64_t> firstElementPos =
2420 llvm::to_vector(extractOp.getStaticPosition());
2421 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2422 int flatIndex = 0;
2423 int stride = 1;
2424 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2425 flatIndex += firstElementPos[i] * stride;
2426 stride *= inputType.getDimSize(i);
2427 }
2428
2429 // Replace the op with a smaller from_elements op.
2430 rewriter.replaceOpWithNewOp<FromElementsOp>(
2431 extractOp, resultType,
2432 fromElementsOp.getElements().slice(flatIndex,
2433 resultType.getNumElements()));
2434 return success();
2435}
2436
2437/// Replace `vector.extract` with `vector.shape_cast`.
2438///
2439/// BEFORE:
2440/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2441/// AFTER:
2442/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2443///
2444/// The canonical form of vector operations that reshape vectors is shape_cast.
2445struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2446 using Base::Base;
2447 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2448 PatternRewriter &rewriter) const override {
2449 VectorType sourceType = extractOp.getSourceVectorType();
2450 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2451 if (!outType)
2452 return failure();
2453
2454 if (sourceType.getNumElements() != outType.getNumElements())
2455 return rewriter.notifyMatchFailure(
2456 extractOp, "extract to vector with fewer elements");
2457
2458 // Negative values in `position` means that the extacted value is poison.
2459 // There is a vector.extract folder for this.
2460 if (llvm::any_of(extractOp.getMixedPosition(),
2461 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2462 return rewriter.notifyMatchFailure(extractOp,
2463 "leaving for extract poison folder");
2464
2465 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2466 extractOp.getSource());
2467
2468 return success();
2469 }
2470};
2471
2472/// Folds vector.extract from vector.insert when the extract position is a
2473/// prefix of the insert position and the remaining (un-indexed) dimensions
2474/// of the extracted sub-vector are all size 1. In that case the extracted
2475/// value is fully determined by the inserted value.
2476///
2477/// Examples:
2478/// %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
2479/// %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
2480/// folds to:
2481/// %ext = vector.broadcast %s : f32 to vector<1xf32>
2482///
2483/// %ins = vector.insert %s, %v [0, 0] : vector<1xf32> into vector<16x1x1xf32>
2484// %ext = vector.extract %ins [0] : vector<1x1xf32> from vector<16x1x1xf32>
2485/// folds to:
2486/// %ext = vector.shape_cast %arg0 : vector<1xf32> to vector<1x1xf32>
2487struct FoldExtractFromInsertUnitDim final
2488 : OpRewritePattern<vector::ExtractOp> {
2489 using Base::Base;
2490
2491 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2492 PatternRewriter &rewriter) const override {
2493 if (extractOp.hasDynamicPosition())
2494 return failure();
2495
2496 auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
2497 if (!insertOp || insertOp.hasDynamicPosition())
2498 return failure();
2499
2500 ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
2501 ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
2502
2503 // The extract position must be a strict prefix of the insert position.
2504 if (extractPos.size() >= insertPos.size() ||
2505 extractPos != insertPos.take_front(extractPos.size()))
2506 return failure();
2507
2508 // The remaining dimensions (those not indexed by the extract) must all
2509 // be size 1 in the source vector type. This guarantees that the inserted
2510 // value fully determines the extracted sub-vector.
2511 auto srcVecType = extractOp.getSourceVectorType();
2512 for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
2513 if (srcVecType.getDimSize(i) != 1)
2514 return failure();
2515
2516 Value inserted = insertOp.getValueToStore();
2517 Type extractedType = extractOp.getResult().getType();
2518 if (isa<VectorType>(inserted.getType())) {
2519 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractedType,
2520 inserted);
2521 } else {
2522 // The inserted value fully determines the extracted sub-vector; broadcast
2523 // it to the extracted type.
2524 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2525 extractOp, extractOp.getResult().getType(),
2526 insertOp.getValueToStore());
2527 }
2528 return success();
2529 }
2530};
2531
2532} // namespace
2533
2534void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2535 MLIRContext *context) {
2536 results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2537 ExtractOpFromConstantMask, ExtractToShapeCast,
2538 FoldExtractFromInsertUnitDim>(context);
2539 results.add(foldExtractFromShapeCastToShapeCast);
2540 results.add(foldExtractFromFromElements);
2541}
2542
2544 SmallVectorImpl<int64_t> &results) {
 // Append the integer value of each attribute in `arrayAttr` to `results`.
 // Existing contents of `results` are kept (push_back only, no clear).
 // NOTE(review): `llvm::cast` asserts if an element is not an IntegerAttr;
 // callers are presumably expected to pass integer-only arrays — confirm.
2545 for (auto attr : arrayAttr)
2546 results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
2547}
2548
2549//===----------------------------------------------------------------------===//
2550// FmaOp
2551//===----------------------------------------------------------------------===//
2552
2553std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2554 return llvm::to_vector<4>(getVectorType().getShape());
2555}
2556
2557//===----------------------------------------------------------------------===//
2558// ToElementsOp
2559//===----------------------------------------------------------------------===//
2560
2561/// Returns true if all the `operands` are defined by `defOp`.
2562/// Otherwise, returns false.
2563static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2564 if (operands.empty())
2565 return false;
2566
2567 return llvm::all_of(operands, [&](Value operand) {
2568 Operation *currentDef = operand.getDefiningOp();
2569 return currentDef == defOp;
2570 });
2571}
2572
2573/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2574/// (%e0, %e1, ...). For example:
2575///
2576/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2577/// %1:3 = vector.to_elements %0 : vector<3xf32>
2578/// user_op %1#0, %1#1, %1#2
2579///
2580/// becomes:
2581///
2582/// user_op %a, %b, %c
2583///
2584static LogicalResult
2585foldToElementsFromElements(ToElementsOp toElementsOp,
2587 auto fromElementsOp =
2588 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2589 if (!fromElementsOp)
2590 return failure();
2591
2592 llvm::append_range(results, fromElementsOp.getElements());
2593 return success();
2594}
2595
2596/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2597///
2598/// Example:
2599/// %b = vector.broadcast %x : i32 to vector<3xf32>
2600/// %e:3 = vector.to_elements %b : vector<3xf32>
2601/// user_op %e#0, %e#1, %e#2
2602/// becomes:
2603/// user_op %x, %x, %x
2604///
2605/// The vector source case is handled by a canonicalization pattern.
2606static LogicalResult
2607foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2609 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2610 if (!bcastOp)
2611 return failure();
2612 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2613 if (isa<VectorType>(bcastOp.getSource().getType()))
2614 return failure();
2615
2616 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2617
2618 Value scalar = bcastOp.getSource();
2619 results.assign(resultVecType.getNumElements(), scalar);
2620 return success();
2621}
2622
2623LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2624 SmallVectorImpl<OpFoldResult> &results) {
2625 if (succeeded(foldToElementsFromElements(*this, results)))
2626 return success();
2627
2628 // Y = ToElements(ShapeCast(X)) -> Y = ToElements(X)
2629 if (auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2630 setOperand(shapeCast.getSource());
2631 return success();
2632 }
2633
2634 return foldToElementsOfBroadcast(*this, results);
2635}
2636
2637LogicalResult
2638ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2639 ToElementsOp::Adaptor adaptor,
2640 SmallVectorImpl<Type> &inferredReturnTypes) {
2641 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2642 Type elType = vecType.getElementType();
2643 inferredReturnTypes.append(vecType.getNumElements(), elType);
2644 return success();
2645}
2646
2647/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2648/// vector.
2649/// - Build `vector.to_elements %v` and remap each destination element to the
2650/// corresponding source element using broadcast rules (match or 1 →
2651/// replicate).
2652///
2653/// Example:
2654/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2655/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2656/// becomes:
2657/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2658/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2659/// // %src_elems#1, %src_elems#0, %src_elems#1
2660struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2661 using Base::Base;
2662
2663 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2664 PatternRewriter &rewriter) const override {
2665 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2666 if (!bcastOp)
2667 return failure();
2668
2669 // Only handle broadcasts from a vector source here.
2670 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2671 if (!srcType)
2672 return failure();
2673
2674 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2675
2676 ArrayRef<int64_t> dstShape = dstType.getShape();
2677 ArrayRef<int64_t> srcShape = srcType.getShape();
2678
2679 int64_t dstRank = dstShape.size();
2680 int64_t srcRank = srcShape.size();
2681
2682 // Create elements for the broadcast source vector.
2683 auto srcElems = vector::ToElementsOp::create(
2684 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2685
2686 int64_t dstCount = llvm::product_of(dstShape);
2687
2688 SmallVector<Value> replacements;
2689 replacements.reserve(dstCount);
2690
2691 // For each element of the destination, determine which element of the
2692 // source should be used. We walk all destination positions using a single
2693 // counter, decode it into per-dimension indices, then build the matching
2694 // source position: use the same index where sizes match, and use 0 where
2695 // the source size is 1 (replication). This mapping is needed so we can
2696 // replace each result of to_elements with the corresponding element from
2697 // the broadcast source.
2698 // Inner-dimension stretch example:
2699 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2700 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2701 // becomes:
2702 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2703 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2704 // // %src_elems#1, %src_elems#0, %src_elems#1,
2705 // // %src_elems#2, %src_elems#3, %src_elems#2,
2706 // // %src_elems#3, %src_elems#2, %src_elems#3
2707
2708 // Row-major strides for the destination shape.
2709 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2710 // Row-major strides for the source shape.
2711 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2712 SmallVector<int64_t> dstIdx(dstRank);
2713 SmallVector<int64_t> srcIdx(srcRank);
2714 for (int64_t lin = 0; lin < dstCount; ++lin) {
2715 // Convert linear destination index to per-dimension indices.
2716 dstIdx = delinearize(lin, dstStrides);
2717 for (int64_t k = 0; k < srcRank; ++k)
2718 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2719 // Convert per-dimension source indices back to a linear index.
2720 int64_t srcLin = linearize(srcIdx, srcStrides);
2721 replacements.push_back(srcElems.getResult(srcLin));
2722 }
2723
2724 rewriter.replaceOp(toElementsOp, replacements);
2725 return success();
2726 }
2727};
2728
2729void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2730 MLIRContext *context) {
2731 results.add<ToElementsOfBroadcast>(context);
2732}
2733
2734//===----------------------------------------------------------------------===//
2735// FromElementsOp
2736//===----------------------------------------------------------------------===//
2737
2738/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2739///
2740/// Case #1: Input and output vectors are the same.
2741///
2742/// %0:3 = vector.to_elements %a : vector<3xf32>
2743/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2744/// user_op %1
2745///
2746/// becomes:
2747///
2748/// user_op %a
2749///
2750static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2751 OperandRange fromElemsOperands = fromElementsOp.getElements();
2752 if (fromElemsOperands.empty())
2753 return {};
2754
2755 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2756 if (!toElementsOp)
2757 return {};
2758
2759 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2760 return {};
2761
2762 // Case #1: Input and output vectors are the same. Forward the input vector.
2763 Value toElementsInput = toElementsOp.getSource();
2764 if (fromElementsOp.getType() == toElementsInput.getType() &&
2765 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2766 return toElementsInput;
2767 }
2768
2769 // TODO: Support cases with different input and output shapes and different
2770 // number of elements.
2771
2772 return {};
2773}
2774
2775/// Fold vector.from_elements to a constant when all operands are constants.
2776/// Example:
2777/// %c1 = arith.constant 1 : i32
2778/// %c2 = arith.constant 2 : i32
2779/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2780/// =>
2781/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2782///
2783static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2784 ArrayRef<Attribute> elements) {
2785 // Check for null or poison attributes before any processing.
2786 if (llvm::any_of(elements, [](Attribute attr) {
2787 return !attr || matchPattern(attr, ub::m_Poison());
2788 }))
2789 return {};
2790
2791 // DenseElementsAttr only supports int/index/float/complex types.
2792 auto destVecType = fromElementsOp.getDest().getType();
2793 auto destEltType = destVecType.getElementType();
2794 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2795 return {};
2796
2797 // Constant attributes might have a different type than the return type.
2798 // Convert them before creating the dense elements attribute.
2799 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2800 return convertNumericAttr(attr, destEltType);
2801 });
2802
2803 return DenseElementsAttr::get(destVecType, convertedElements);
2804}
2805
2806OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2807 if (auto res = foldFromElementsToElements(*this))
2808 return res;
2809 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2810 return res;
2811
2812 return {};
2813}
2814
2815/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2816/// same. Example:
2817/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2818/// =>
2819/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2820static LogicalResult
2821rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2822 PatternRewriter &rewriter) {
2823 if (!llvm::all_equal(fromElementsOp.getElements()))
2824 return failure();
2825 rewriter.replaceOpWithNewOp<BroadcastOp>(
2826 fromElementsOp, fromElementsOp.getType(),
2827 fromElementsOp.getElements().front());
2828 return success();
2829}
2830
2831/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2832/// on a single extract. Example:
2833/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2834/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2835/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2836///
2837/// becomes
2838/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2839/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2840///
2841/// The requirements for this to be valid are
2842///
2843/// i) The elements are extracted from the same vector (%source).
2844///
2845/// ii) The elements form a suffix of %source. Specifically, the number
2846/// of elements is the same as the product of the last N dimension sizes
2847/// of %source, for some N.
2848///
2849/// iii) The elements are extracted contiguously in ascending order.
2850
2851class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2852
2853 using Base::Base;
2854
2855 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2856 PatternRewriter &rewriter) const override {
2857
2858 // Handled by `rewriteFromElementsAsBroadcast`.
2859 if (fromElements.getType().getNumElements() == 1)
2860 return failure();
2861
2862 // The common source that all elements are extracted from, if one exists.
2864 // The position of the combined extract operation, if one is created.
2865 ArrayRef<int64_t> combinedPosition;
2866 // The expected index of extraction of the current element in the loop, if
2867 // elements are extracted contiguously in ascending order.
2868 SmallVector<int64_t> expectedPosition;
2869
2870 for (auto [insertIndex, element] :
2871 llvm::enumerate(fromElements.getElements())) {
2872
2873 // Check that the element is from a vector.extract operation.
2874 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2875 if (!extractOp) {
2876 return rewriter.notifyMatchFailure(fromElements,
2877 "element not from vector.extract");
2878 }
2879
2880 // Check condition (i) by checking that all elements have the same source
2881 // as the first element.
2882 if (insertIndex == 0) {
2883 source = extractOp.getSource();
2884 } else if (extractOp.getSource() != source) {
2885 return rewriter.notifyMatchFailure(fromElements,
2886 "element from different vector");
2887 }
2888
2889 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2890 int64_t rank = position.size();
2891 assert(rank == source.getType().getRank() &&
2892 "scalar extract must have full rank position");
2893
2894 // Check condition (ii) by checking that the position that the first
2895 // element is extracted from has sufficient trailing 0s. For example, in
2896 //
2897 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2898 // [...]
2899 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2900 //
2901 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2902 // elements, which is the number of elements of %n, so this is valid.
2903 if (insertIndex == 0) {
2904 const int64_t numElms = fromElements.getType().getNumElements();
2905 int64_t numSuffixElms = 1;
2906 int64_t index = rank;
2907 while (index > 0 && position[index - 1] == 0 &&
2908 numSuffixElms < numElms) {
2909 numSuffixElms *= source.getType().getDimSize(index - 1);
2910 --index;
2911 }
2912 if (numSuffixElms != numElms) {
2913 return rewriter.notifyMatchFailure(
2914 fromElements, "elements do not form a suffix of source");
2915 }
2916 expectedPosition = llvm::to_vector(position);
2917 combinedPosition = position.drop_back(rank - index);
2918 }
2919
2920 // Check condition (iii).
2921 else if (expectedPosition != position) {
2922 return rewriter.notifyMatchFailure(
2923 fromElements, "elements not in ascending order (static order)");
2924 }
2925 increment(expectedPosition, source.getType().getShape());
2926 }
2927
2928 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2929 fromElements.getLoc(), source, combinedPosition);
2930
2931 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2932 fromElements, fromElements.getType(), extracted);
2933
2934 return success();
2935 }
2936
2937 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2938 static void increment(MutableArrayRef<int64_t> indices,
2940 for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2941 indices[dim] += 1;
2942 if (indices[dim] < shape[dim])
2943 break;
2944 indices[dim] = 0;
2945 }
2946 }
2947};
2948
2949void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2950 MLIRContext *context) {
2952 results.add<FromElementsToShapeCast>(context);
2953}
2954
2955//===----------------------------------------------------------------------===//
2956// BroadcastOp
2957//===----------------------------------------------------------------------===//
2958
2959void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2960 SetIntRangeFn setResultRanges) {
2961 setResultRanges(getResult(), argRanges.front());
2962}
2963
2964std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2965 return llvm::to_vector<4>(getResultVectorType().getShape());
2966}
2967
2968/// Return the dimensions of the result vector that were formerly ones in the
2969/// source tensor and thus correspond to "dim-1" broadcasting.
2970static llvm::SetVector<int64_t>
2972 ArrayRef<int64_t> dstShape) {
2973 int64_t rankDiff = dstShape.size() - srcShape.size();
2974 int64_t dstDim = rankDiff;
2976 for (auto [s1, s2] :
2977 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2978 if (s1 != s2) {
2979 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2980 res.insert(dstDim);
2981 }
2982 ++dstDim;
2983 }
2984 return res;
2985}
2986
2987llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2988 // Scalar broadcast is without any unit dim broadcast.
2989 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2990 if (!srcVectorType)
2991 return {};
2992 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2993 getResultVectorType().getShape());
2994}
2995
2996/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2997/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2998/// This requires (and asserts) that the broadcast is free of "dim-1"
2999/// broadcasting.
3000/// Since vector.broadcast only allows expanding leading dimensions, an extra
3001/// vector.transpose may be inserted to make the broadcast possible.
3002/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
3003/// the helper will assert. This means:
3004/// 1. `dstShape` must not be empty.
3005/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
3006/// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
3007// must match the `value` shape.
3008Value BroadcastOp::createOrFoldBroadcastOp(
3009 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
3010 const llvm::SetVector<int64_t> &broadcastedDims) {
3011 assert(!dstShape.empty() && "unexpected empty dst shape");
3012
3013 // Well-formedness check.
3014 SmallVector<int64_t> checkShape;
3015 for (int i = 0, e = dstShape.size(); i < e; ++i) {
3016 if (broadcastedDims.contains(i))
3017 continue;
3018 checkShape.push_back(dstShape[i]);
3019 }
3020 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
3021 "ill-formed broadcastedDims contains values not confined to "
3022 "destVectorShape");
3023
3024 Location loc = value.getLoc();
3025 Type elementType = getElementTypeOrSelf(value.getType());
3026 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
3027 VectorType dstVectorType = VectorType::get(dstShape, elementType);
3028
3029 // Step 2. If scalar -> dstShape broadcast, just do it.
3030 if (!srcVectorType) {
3031 assert(checkShape.empty() &&
3032 "ill-formed createOrFoldBroadcastOp arguments");
3033 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
3034 }
3035
3036 assert(srcVectorType.getShape().equals(checkShape) &&
3037 "ill-formed createOrFoldBroadcastOp arguments");
3038
3039 // Step 3. Since vector.broadcast only allows creating leading dims,
3040 // vector -> dstShape broadcast may require a transpose.
3041 // Traverse the dims in order and construct:
3042 // 1. The leading entries of the broadcastShape that is guaranteed to be
3043 // achievable by a simple broadcast.
3044 // 2. The induced permutation for the subsequent vector.transpose that will
3045 // bring us from `broadcastShape` back to he desired `dstShape`.
3046 // If the induced permutation is not the identity, create a vector.transpose.
3047 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
3048 broadcastShape.reserve(dstShape.size());
3049 // Consider the example:
3050 // srcShape = 2x4
3051 // dstShape = 1x2x3x4x5
3052 // broadcastedDims = [0, 2, 4]
3053 //
3054 // We want to build:
3055 // broadcastShape = 1x3x5x2x4
3056 // permutation = [0, 2, 4, 1, 3]
3057 // ---V--- -----V-----
3058 // leading broadcast part src shape part
3059 //
3060 // Note that the trailing dims of broadcastShape are exactly the srcShape
3061 // by construction.
3062 // nextSrcShapeDim is used to keep track of where in the permutation the
3063 // "src shape part" occurs.
3064 int64_t nextSrcShapeDim = broadcastedDims.size();
3065 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
3066 if (broadcastedDims.contains(i)) {
3067 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
3068 // bring it to the head of the broadcastShape.
3069 // It will need to be permuted back from `broadcastShape.size() - 1` into
3070 // position `i`.
3071 broadcastShape.push_back(dstShape[i]);
3072 permutation[i] = broadcastShape.size() - 1;
3073 } else {
3074 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
3075 // shape and needs to be permuted into position `i`.
3076 // Don't touch `broadcastShape` here, the whole srcShape will be
3077 // appended after.
3078 permutation[i] = nextSrcShapeDim++;
3079 }
3080 }
3081 // 3.c. Append the srcShape.
3082 llvm::append_range(broadcastShape, srcVectorType.getShape());
3083
3084 // Ensure there are no "dim-1" broadcasts.
3085 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
3086 .empty() &&
3087 "unexpected \"dim-1\" broadcast");
3088
3089 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3090 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
3091 vector::BroadcastableToResult::Success &&
3092 "must be broadcastable");
3093 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3094 // Step 4. If we find any dimension that indeed needs to be permuted,
3095 // immediately return a new vector.transpose.
3096 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3097 if (permutation[i] != i)
3098 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3099 // Otherwise return res.
3100 return res;
3101}
3102
3104 Type srcType, VectorType dstVectorType,
3105 std::pair<VectorDim, VectorDim> *mismatchingDims) {
 // Checks whether `srcType` (a scalar or a vector) may be broadcast to
 // `dstVectorType`. Source dims are compared against the trailing dims of the
 // destination. For an incompatible dim pair, the source/destination dims
 // (size and scalability) are written to `*mismatchingDims` when the pointer
 // is non-null.
3106 // Broadcast scalar to vector of the same element type.
3107 if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
3108 srcType == getElementTypeOrSelf(dstVectorType))
3110 // From now on, only vectors broadcast.
3111 VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
3112 if (!srcVectorType)
3114
 // NOTE(review): `dstVectorType` is null-checked in the scalar case above
 // but used unconditionally from here on; presumably callers always pass a
 // non-null destination type — confirm.
3115 int64_t srcRank = srcVectorType.getRank();
3116 int64_t dstRank = dstVectorType.getRank();
3117 if (srcRank > dstRank)
3119 // Source has an exact match or singleton value for all trailing dimensions
3120 // (all leading dimensions are simply duplicated).
 // `lead` is the number of extra leading destination dims: source dim i is
 // compared with destination dim `lead + i`.
3121 int64_t lead = dstRank - srcRank;
3122 for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
3123 // Have mismatching dims (in the sense of vector.broadcast semantics) been
3124 // encountered?
3125 bool foundMismatchingDims = false;
3126
3127 // Check fixed-width dims.
3128 int64_t srcDim = srcVectorType.getDimSize(dimIdx);
3129 int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
3130 if (srcDim != 1 && srcDim != dstDim)
3131 foundMismatchingDims = true;
3132
3133 // Check scalable flags.
 // A scalable unit source dim ([1]) may only map to [1]. For every other
 // dim, source and destination scalability must agree, except that a
 // fixed unit dim may stretch to a scalable one (1 -> [N]).
3134 bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
3135 bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
3136 if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
3137 // 1 -> [N] is fine, everything else should be rejected when mixing
3138 // fixed-width and scalable dims
3139 (srcDimScalableFlag != dstDimScalableFlag &&
3140 (srcDim != 1 || srcDimScalableFlag)))
3141 foundMismatchingDims = true;
3142
3143 if (foundMismatchingDims) {
 // Report the offending pair to the caller, if requested.
3144 if (mismatchingDims != nullptr) {
3145 mismatchingDims->first.dim = srcDim;
3146 mismatchingDims->first.isScalable = srcDimScalableFlag;
3147
3148 mismatchingDims->second.dim = dstDim;
3149 mismatchingDims->second.isScalable = dstDimScalableFlag;
3150 }
3152 }
3153 }
3154
3156}
3157
/// Verifies vector.broadcast. The compatibility check is delegated to a call
/// taking (source type, result vector type, &mismatchingDims) — presumably
/// isBroadcastableTo, whose parameter list matches. Each failure kind is
/// then turned into a diagnostic. A dimension mismatch prints both dims and
/// wraps scalable ones in square brackets, e.g. "dimension mismatch ([4] vs. 8)".
3158LogicalResult BroadcastOp::verify() {
3159 std::pair<VectorDim, VectorDim> mismatchingDims;
 // Filled in by the broadcastability check when a dimension pair mismatches.
3161 getSourceType(), getResultVectorType(), &mismatchingDims);
3163 return success();
3165 return emitOpError("source rank higher than destination rank");
 // Scalable dims are printed as "[N]", fixed ones as plain "N".
3167 return emitOpError("dimension mismatch (")
3168 << (mismatchingDims.first.isScalable ? "[" : "")
3169 << mismatchingDims.first.dim
3170 << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
3171 << (mismatchingDims.second.isScalable ? "[" : "")
3172 << mismatchingDims.second.dim
3173 << (mismatchingDims.second.isScalable ? "]" : "") << ")";
3174 }
3176 return emitOpError("source type is not a vector");
 // Every failure kind is expected to be handled above.
3177 llvm_unreachable("unexpected vector.broadcast op error");
3178}
3179
3180// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3181// with broadcast's result type and shape_cast only adds or removes ones in the
3182// leading dimensions.
3183static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3184 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3185 if (!srcShapeCast)
3186 return failure();
3187
3188 VectorType srcType = srcShapeCast.getSourceVectorType();
3189 VectorType destType = broadcastOp.getResultVectorType();
3190 // Check type compatibility.
3191 if (vector::isBroadcastableTo(srcType, destType) !=
3193 return failure();
3194
3195 ArrayRef<int64_t> srcShape = srcType.getShape();
3196 ArrayRef<int64_t> shapecastShape =
3197 srcShapeCast.getResultVectorType().getShape();
3198 // Trailing dimensions should be the same if shape_cast only alters the
3199 // leading dimensions.
3200 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3201 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3202 shapecastShape.take_back(numTrailingDims)))
3203 return failure();
3204
3205 assert(all_of(srcShape.drop_back(numTrailingDims),
3206 [](int64_t E) { return E == 1; }) &&
3207 all_of(shapecastShape.drop_back(numTrailingDims),
3208 [](int64_t E) { return E == 1; }) &&
3209 "ill-formed shape_cast");
3210
3211 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3212 return success();
3213}
3214
3215OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3216 if (getSourceType() == getResultVectorType())
3217 return getSource();
3218 if (succeeded(foldBroadcastOfShapeCast(*this)))
3219 return getResult();
3220
3221 if (!adaptor.getSource())
3222 return {};
3223 auto vectorType = getResultVectorType();
3224 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3225 if (vectorType.getElementType() != attr.getType())
3226 return {};
3227 return DenseElementsAttr::get(vectorType, attr);
3228 }
3229 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3230 if (vectorType.getElementType() != attr.getType())
3231 return {};
3232 return DenseElementsAttr::get(vectorType, attr);
3233 }
3234 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3235 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3236 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
3237 return ub::PoisonAttr::get(getContext());
3238 return {};
3239}
3240
3241namespace {
3242
3243// Fold broadcast1(broadcast2(x)) into broadcast1(x).
3244struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3245 using Base::Base;
3246
3247 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3248 PatternRewriter &rewriter) const override {
3249 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3250 if (!srcBroadcast)
3251 return failure();
3252 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3253 broadcastOp.getResultVectorType(),
3254 srcBroadcast.getSource());
3255 return success();
3256 }
3257};
3258
3259/// Replace `vector.broadcast` with `vector.shape_cast`.
3260///
3261/// BEFORE:
3262/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3263/// AFTER:
3264/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3265///
3266/// The canonical form of vector operations that reshape vectors is shape_cast.
3267struct BroadcastToShapeCast final
3268 : public OpRewritePattern<vector::BroadcastOp> {
3269 using Base::Base;
3270 LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
3271 PatternRewriter &rewriter) const override {
3272
3273 auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
3274 if (!sourceType) {
3275 return rewriter.notifyMatchFailure(
3276 broadcast, "source is a scalar, shape_cast doesn't support scalar");
3277 }
3278
3279 VectorType outType = broadcast.getType();
3280 if (sourceType.getNumElements() != outType.getNumElements()) {
3281 return rewriter.notifyMatchFailure(
3282 broadcast, "broadcast to a greater number of elements");
3283 }
3284
3285 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
3286 broadcast.getSource());
3287 return success();
3288 }
3289};
3290} // namespace
3291
3292void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3293 MLIRContext *context) {
3294 results.add<BroadcastFolder, BroadcastToShapeCast>(context);
3295}
3296
3297//===----------------------------------------------------------------------===//
3298// ShuffleOp
3299//===----------------------------------------------------------------------===//
3300
3301LogicalResult ShuffleOp::verify() {
3302 VectorType resultType = getResultVectorType();
3303 VectorType v1Type = getV1VectorType();
3304 VectorType v2Type = getV2VectorType();
3305 // Verify ranks.
3306 int64_t resRank = resultType.getRank();
3307 int64_t v1Rank = v1Type.getRank();
3308 int64_t v2Rank = v2Type.getRank();
3309 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3310 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3311 if (!wellFormed0DCase && !wellFormedNDCase)
3312 return emitOpError("rank mismatch");
3313
3314 // Verify all but leading dimension sizes.
3315 for (int64_t r = 1; r < v1Rank; ++r) {
3316 int64_t resDim = resultType.getDimSize(r);
3317 int64_t v1Dim = v1Type.getDimSize(r);
3318 int64_t v2Dim = v2Type.getDimSize(r);
3319 if (resDim != v1Dim || v1Dim != v2Dim)
3320 return emitOpError("dimension mismatch");
3321 }
3322 // Verify mask length.
3323 ArrayRef<int64_t> mask = getMask();
3324 int64_t maskLength = mask.size();
3325 if (maskLength <= 0)
3326 return emitOpError("invalid mask length");
3327 if (maskLength != resultType.getDimSize(0))
3328 return emitOpError("mask length mismatch");
3329 // Verify all indices.
3330 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3331 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3332 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3333 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3334 return emitOpError("mask index #") << (idx + 1) << " out of range";
3335 }
3336 return success();
3337}
3338
3339LogicalResult
3340ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3341 ShuffleOp::Adaptor adaptor,
3342 SmallVectorImpl<Type> &inferredReturnTypes) {
3343 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3344 if (!v1Type) {
3345 return emitOptionalError(loc, "expected vector type");
3346 }
3347 auto v1Rank = v1Type.getRank();
3348 // Construct resulting type: leading dimension matches mask
3349 // length, all trailing dimensions match the operands.
3350 SmallVector<int64_t, 4> shape;
3351 shape.reserve(v1Rank);
3352 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3353 // In the 0-D case there is no trailing shape to append.
3354 if (v1Rank > 0)
3355 llvm::append_range(shape, v1Type.getShape().drop_front());
3356 inferredReturnTypes.push_back(
3357 VectorType::get(shape, v1Type.getElementType()));
3358 return success();
3359}
3360
3361template <typename T>
3362static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3363 T expected = begin;
3364 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3365 return value == expected++;
3366 });
3367}
3368
3369/// Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
3370/// Fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2.
3372 auto v1Type = op.getV1VectorType();
3373 auto v2Type = op.getV2VectorType();
3374 auto mask = op.getMask();
3375 if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
3376 return op.getV1();
3377 if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
3378 return op.getV2();
3379 return {};
3380}
3381
3382/// If a shuffle operand is poison, replace all mask indices that reference it
3383/// with kPoisonIndex. This is an in-place fold.
3385 bool isV1Poison = matchPattern(op.getV1(), ub::m_Poison());
3386 bool isV2Poison = matchPattern(op.getV2(), ub::m_Poison());
3387 if (!isV1Poison && !isV2Poison)
3388 return {};
3389
3390 int64_t v1Size = op.getV1VectorType().getDimSize(0);
3391 bool changed = false;
3392 SmallVector<int64_t> newMask = llvm::to_vector(op.getMask());
3393 for (int64_t &idx : newMask) {
3394 if (idx == ShuffleOp::kPoisonIndex)
3395 continue;
3396 if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
3397 idx = ShuffleOp::kPoisonIndex;
3398 changed = true;
3399 }
3400 }
3401
3402 if (!changed)
3403 return {};
3404
3405 op.setMask(newMask);
3406 return op.getResult();
3407}
3408
3409/// Fold shuffle poison, poison -> poison.
3411 Attribute v1Attr,
3412 Attribute v2Attr) {
3413 if (matchPattern(v1Attr, ub::m_Poison()) &&
3414 matchPattern(v2Attr, ub::m_Poison()))
3415 return ub::PoisonAttr::get(context);
3416 return {};
3417}
3418
3419/// Fold a shuffle of constant 1-D inputs by evaluating the mask.
3421 Attribute v2Attr) {
3422 auto v1Type = op.getV1VectorType();
3423 if (v1Type.getRank() != 1)
3424 return {};
3425
3426 bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
3427 bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());
3428
3429 // Poison input attributes need special handling as they are not
3430 // DenseElementsAttr. If an index is poison, we select the first element of
3431 // the first non-poison input.
3432 SmallVector<Attribute> v1Elements, v2Elements;
3433 Attribute poisonElement;
3434 if (!isV2Poison) {
3435 auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
3436 if (!v2DenseAttr)
3437 return {};
3438 v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
3439 poisonElement = v2Elements[0];
3440 }
3441 if (!isV1Poison) {
3442 auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
3443 if (!v1DenseAttr)
3444 return {};
3445 v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
3446 poisonElement = v1Elements[0];
3447 }
3448
3449 ArrayRef<int64_t> mask = op.getMask();
3450 SmallVector<Attribute> results;
3451 int64_t v1Size = v1Type.getDimSize(0);
3452 for (int64_t maskIdx : mask) {
3453 Attribute indexedElm;
3454 // TODO: Return a partial poison vector when supported by the UB dialect.
3455 if (maskIdx == ShuffleOp::kPoisonIndex) {
3456 indexedElm = poisonElement;
3457 } else {
3458 if (maskIdx < v1Size)
3459 indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
3460 else
3461 indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
3462 }
3463
3464 results.push_back(indexedElm);
3465 }
3466
3467 return DenseElementsAttr::get(op.getResultVectorType(), results);
3468}
3469
3470OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3471 auto v1Type = getV1VectorType();
3472
3473 assert(!v1Type.isScalable() && !getV2VectorType().isScalable() &&
3474 "Vector shuffle does not support scalable vectors");
3475
3476 // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
3477 // but must be a canonicalization into a vector.broadcast.
3478 if (v1Type.getRank() == 0)
3479 return {};
3480
3481 if (auto res = foldShuffleIdentityMask(*this))
3482 return res;
3483 if (auto res = foldShufflePoisonOperandToMask(*this))
3484 return res;
3485
3486 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3487 if (!v1Attr || !v2Attr)
3488 return {};
3489
3490 if (auto res = foldShufflePoisonInputs(getContext(), v1Attr, v2Attr))
3491 return res;
3492 if (auto res = foldShuffleConstantInputs(*this, v1Attr, v2Attr))
3493 return res;
3494
3495 return {};
3496}
3497
3498namespace {
3499
3500// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3501// to a broadcast.
3502struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3503 using Base::Base;
3504
3505 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3506 PatternRewriter &rewriter) const override {
3507 VectorType v1VectorType = shuffleOp.getV1VectorType();
3508 ArrayRef<int64_t> mask = shuffleOp.getMask();
3509 if (v1VectorType.getRank() > 0)
3510 return failure();
3511 if (mask.size() != 1)
3512 return failure();
3513 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3514 if (mask[0] == 0)
3515 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3516 shuffleOp.getV1());
3517 else
3518 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3519 shuffleOp.getV2());
3520 return success();
3521 }
3522};
3523
3524/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3525/// vector.broadcast with a scalar operand, return the scalar value that is
3526/// splatted. Otherwise return null.
3527///
3528/// Example:
3529///
3530/// scalar_source --> vector.broadcast --> value - return scalar_source
3531static Value getScalarSplatSource(Value value) {
3532 // Block argument:
3533 Operation *defOp = value.getDefiningOp();
3534 if (!defOp)
3535 return {};
3536
3537 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3538
3539 // Not broadcast (and not splat):
3540 if (!broadcast)
3541 return {};
3542
3543 // Broadcast of a vector:
3544 if (isa<VectorType>(broadcast.getSourceType()))
3545 return {};
3546
3547 // Broadcast of a scalar:
3548 return broadcast.getSource();
3549}
3550
3551/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
3552class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3553public:
3554 using Base::Base;
3555
3556 LogicalResult matchAndRewrite(ShuffleOp op,
3557 PatternRewriter &rewriter) const override {
3558 Value splat = getScalarSplatSource(op.getV1());
3559 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3560 return failure();
3561
3562 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3563 return success();
3564 }
3565};
3566
3567/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3568/// vector.interleave.
3569class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3570public:
3571 using Base::Base;
3572
3573 LogicalResult matchAndRewrite(ShuffleOp op,
3574 PatternRewriter &rewriter) const override {
3575 VectorType resultType = op.getResultVectorType();
3576 if (resultType.isScalable())
3577 return rewriter.notifyMatchFailure(
3578 op, "ShuffleOp can't represent a scalable interleave");
3579
3580 if (resultType.getRank() != 1)
3581 return rewriter.notifyMatchFailure(
3582 op, "ShuffleOp can't represent an n-D interleave");
3583
3584 VectorType sourceType = op.getV1VectorType();
3585 if (sourceType != op.getV2VectorType() ||
3586 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3587 return rewriter.notifyMatchFailure(
3588 op, "ShuffleOp types don't match an interleave");
3589 }
3590
3591 ArrayRef<int64_t> shuffleMask = op.getMask();
3592 int64_t resultVectorSize = resultType.getNumElements();
3593 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3594 int64_t maskValueA = shuffleMask[i * 2];
3595 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3596 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3597 return rewriter.notifyMatchFailure(op,
3598 "ShuffleOp mask not interleaving");
3599 }
3600
3601 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3602 return success();
3603 }
3604};
3605
3606/// Pattern to replace usused shuffle operands / results with poison.
3607///
3608/// Example Input:
3609/// %r = vector.shuffle %v1, %v2 [2, 3, 3, 3] : vector<2xi32>, vector<2xi32>
3610///
3611/// Example Output:
3612/// %0 = ub.poison : vector<2xi32>
3613/// %r = vector.shuffle %0, %v2 [2, 3, 3, 3] : vector<2xi32>, vector<2xi32>
3614class FoldUnusedShuffleOperand final : public OpRewritePattern<ShuffleOp> {
3615public:
3616 using Base::Base;
3617
3618 LogicalResult matchAndRewrite(ShuffleOp op,
3619 PatternRewriter &rewriter) const override {
3620 // Replace with poison if all mask elements are poison.
3621 if (llvm::all_of(op.getMask(), [](int64_t mask) {
3622 return mask == ShuffleOp::kPoisonIndex;
3623 })) {
3624 rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType());
3625 return success();
3626 }
3627
3628 // Helper function to replace an operand with poison.
3629 auto replaceOperandWithPoison = [&](OpOperand &operand) {
3630 // Do not replace if the operand is already poison.
3631 if (!matchPattern(operand.get(), ub::m_Poison())) {
3632 Value poison = ub::PoisonOp::create(rewriter, op.getLoc(),
3633 operand.get().getType());
3634 rewriter.modifyOpInPlace(op, [&]() { operand.set(poison); });
3635 return success();
3636 }
3637 return failure();
3638 };
3639
3640 // Replace V1 with poison if it is not used.
3641 int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
3642 ? op.getV1VectorType().getDimSize(0)
3643 : 1;
3644 bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3645 return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
3646 });
3647 if (!isV1Used && succeeded(replaceOperandWithPoison(op.getV1Mutable())))
3648 return success();
3649
3650 // Replace V2 with poison if it is not used.
3651 bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3652 return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
3653 });
3654 if (!isV2Used && succeeded(replaceOperandWithPoison(op.getV2Mutable())))
3655 return success();
3656
3657 return failure();
3658 }
3659};
3660} // namespace
3661
3662void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3663 MLIRContext *context) {
3664 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
3665 FoldUnusedShuffleOperand>(context);
3666}
3667
3668//===----------------------------------------------------------------------===//
3669// InsertOp
3670//===----------------------------------------------------------------------===//
3671
3672void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3673 SetIntRangeFn setResultRanges) {
3674 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3675}
3676
3677void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3678 Value source, Value dest) {
3679 auto vectorTy = cast<VectorType>(dest.getType());
3680 build(builder, result, source, dest,
3681 SmallVector<int64_t>(vectorTy.getRank(), 0));
3682}
3683
3684void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3685 Value source, Value dest, int64_t position) {
3686 build(builder, result, source, dest, ArrayRef<int64_t>{position});
3687}
3688
3689void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3690 Value source, Value dest, OpFoldResult position) {
3691 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3692}
3693
3694void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3695 Value source, Value dest,
3696 ArrayRef<int64_t> position) {
3697 SmallVector<OpFoldResult> posVals;
3698 posVals.reserve(position.size());
3699 llvm::transform(position, std::back_inserter(posVals),
3700 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3701 build(builder, result, source, dest, posVals);
3702}
3703
3704void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3705 Value source, Value dest,
3706 ArrayRef<OpFoldResult> position) {
3707 SmallVector<int64_t> staticPos;
3708 SmallVector<Value> dynamicPos;
3709 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3710 build(builder, result, source, dest, dynamicPos,
3711 builder.getDenseI64ArrayAttr(staticPos));
3712}
3713
3714LogicalResult InsertOp::verify() {
3715 if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3716 if (srcTy.getRank() == 0)
3717 return emitError(
3718 "expected a scalar instead of a 0-d vector as the source operand");
3719
3720 SmallVector<OpFoldResult> position = getMixedPosition();
3721 auto destVectorType = getDestVectorType();
3722 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3723 return emitOpError(
3724 "expected position attribute of rank no greater than dest vector rank");
3725 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3726 if (srcVectorType &&
3727 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3728 static_cast<unsigned>(destVectorType.getRank())))
3729 return emitOpError("expected position attribute rank + source rank to "
3730 "match dest vector rank");
3731 if (!srcVectorType &&
3732 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3733 return emitOpError(
3734 "expected position attribute rank to match the dest vector rank");
3735 for (auto [idx, pos] : llvm::enumerate(position)) {
3736 if (auto attr = dyn_cast<Attribute>(pos)) {
3737 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3738 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3739 destVectorType.getDimSize(idx))) {
3740 return emitOpError("expected position attribute #")
3741 << (idx + 1)
3742 << " to be a non-negative integer smaller than the "
3743 "corresponding "
3744 "dest vector dimension";
3745 }
3746 }
3747 }
3748 return success();
3749}
3750
3751// Calculate the linearized position of the continuous chunk of elements to
3752// insert, based on the shape of the value to insert and the positions to insert
3753// at.
3754static int64_t calculateInsertPosition(VectorType destTy,
3755 ArrayRef<int64_t> positions) {
3756 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3757 assert(positions.size() <= completePositions.size() &&
3758 "positions size must be less than or equal to destTy rank");
3759 copy(positions, completePositions.begin());
3760 return linearize(completePositions, computeStrides(destTy.getShape()));
3761}
3762
3763namespace {
3764
3765// If insertOp is only inserting unit dimensions it can be transformed to a
3766// broadcast.
3767class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3768public:
3769 using Base::Base;
3770
3771 LogicalResult matchAndRewrite(InsertOp insertOp,
3772 PatternRewriter &rewriter) const override {
3773 auto srcVecType =
3774 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3775 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3776 srcVecType.getNumElements())
3777 return failure();
3778 rewriter.replaceOpWithNewOp<BroadcastOp>(
3779 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3780 return success();
3781 }
3782};
3783
3784/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
3785class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3786public:
3787 using Base::Base;
3788
3789 LogicalResult matchAndRewrite(InsertOp op,
3790 PatternRewriter &rewriter) const override {
3791
3792 Value splat = getScalarSplatSource(op.getValueToStore());
3793 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3794 return failure();
3795
3796 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3797 return success();
3798 }
3799};
3800
3801/// Pattern to optimize a chain of insertions.
3802///
3803/// This pattern identifies chains of vector.insert operations that:
3804/// 1. Only insert values at static positions.
3805/// 2. Completely initialize all elements in the resulting vector.
3806/// 3. All intermediate insert operations have only one use.
3807///
3808/// When these conditions are met, the entire chain can be replaced with a
3809/// single vector.from_elements operation.
3810///
3811/// To keep this pattern simple, and avoid spending too much time on matching
3812/// fragmented insert chains, this pattern only considers the last insert op in
3813/// the chain.
3814///
3815/// Example transformation:
3816/// %poison = ub.poison : vector<2xi32>
3817/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
3818/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
3819/// ->
3820/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
3821class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
3822public:
3823 using Base::Base;
3824 LogicalResult matchAndRewrite(InsertOp op,
3825 PatternRewriter &rewriter) const override {
3826
3827 VectorType destTy = op.getDestVectorType();
3828 if (destTy.isScalable())
3829 return failure();
3830 // Ensure this is the trailing vector.insert op in a chain of inserts.
3831 for (Operation *user : op.getResult().getUsers())
3832 if (auto insertOp = dyn_cast<InsertOp>(user))
3833 if (insertOp.getDest() == op.getResult())
3834 return failure();
3835
3836 InsertOp currentOp = op;
3837 SmallVector<InsertOp> chainInsertOps;
3838 while (currentOp) {
3839 // Check cond 1: Dynamic position is not supported.
3840 if (currentOp.hasDynamicPosition())
3841 return failure();
3842
3843 chainInsertOps.push_back(currentOp);
3844 currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
3845 // Check cond 3: Intermediate inserts have only one use to avoid an
3846 // explosion of vectors.
3847 if (currentOp && !currentOp->hasOneUse())
3848 return failure();
3849 }
3850
3851 int64_t vectorSize = destTy.getNumElements();
3852 int64_t initializedCount = 0;
3853 SmallVector<bool> initializedDestIdxs(vectorSize, false);
3854 SmallVector<int64_t> pendingInsertPos;
3855 SmallVector<int64_t> pendingInsertSize;
3856 SmallVector<Value> pendingInsertValues;
3857
3858 for (auto insertOp : chainInsertOps) {
3859 // This pattern can do nothing with poison index.
3860 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3861 return failure();
3862
3863 // Calculate the linearized position for inserting elements.
3864 int64_t insertBeginPosition =
3865 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3866
3867 // The valueToStore operand may be a vector or a scalar. Need to handle
3868 // both cases.
3869 int64_t insertSize = 1;
3870 if (auto srcVectorType =
3871 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
3872 insertSize = srcVectorType.getNumElements();
3873
3874 assert(insertBeginPosition + insertSize <= vectorSize &&
3875 "insert would overflow the vector");
3876
3877 for (auto index : llvm::seq<int64_t>(insertBeginPosition,
3878 insertBeginPosition + insertSize)) {
3879 if (initializedDestIdxs[index])
3880 continue;
3881 initializedDestIdxs[index] = true;
3882 ++initializedCount;
3883 }
3884
3885 // Defer the creation of ops before we can make sure the pattern can
3886 // succeed.
3887 pendingInsertPos.push_back(insertBeginPosition);
3888 pendingInsertSize.push_back(insertSize);
3889 pendingInsertValues.push_back(insertOp.getValueToStore());
3890
3891 if (initializedCount == vectorSize)
3892 break;
3893 }
3894
3895 // Check cond 2: all positions must be initialized.
3896 if (initializedCount != vectorSize)
3897 return failure();
3898
3899 SmallVector<Value> elements(vectorSize);
3900 for (auto [insertBeginPosition, insertSize, valueToStore] :
3901 llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
3902 pendingInsertValues))) {
3903 auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());
3904
3905 if (!srcVectorType) {
3906 elements[insertBeginPosition] = valueToStore;
3907 continue;
3908 }
3909
3910 Repeated<Type> elementToInsertTypes(insertSize,
3911 srcVectorType.getElementType());
3912 // Get all elements from the vector in row-major order.
3913 auto elementsToInsert = vector::ToElementsOp::create(
3914 rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
3915 for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
3916 elements[insertBeginPosition + linearIdx] =
3917 elementsToInsert.getResult(linearIdx);
3918 }
3919 }
3920
3921 rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
3922 return success();
3923 }
3924};
3925
3926} // namespace
3927
3928static Attribute
3930 Attribute dstAttr,
3931 int64_t maxVectorSizeFoldThreshold) {
3932 if (insertOp.hasDynamicPosition())
3933 return {};
3934
3935 auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3936 if (!denseDst)
3937 return {};
3938
3939 if (!srcAttr) {
3940 return {};
3941 }
3942
3943 VectorType destTy = insertOp.getDestVectorType();
3944 if (destTy.isScalable())
3945 return {};
3946
3947 // Make sure we do not create too many large constants.
3948 if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3949 !insertOp->hasOneUse())
3950 return {};
3951
3952 // Bail out on poison indices (kPoisonIndex = -1) to avoid computing an
3953 // invalid (negative) linearized position which would cause UB below.
3954 if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
3955 return {};
3956
3957 // Calculate the linearized position for inserting elements.
3958 int64_t insertBeginPosition =
3959 calculateInsertPosition(destTy, insertOp.getStaticPosition());
3960 SmallVector<Attribute> insertedValues;
3961 Type destEltType = destTy.getElementType();
3962
3963 /// Converts attribute to the expected type if there's
3964 /// a mismatch.
3965 if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3966 for (auto value : denseSource.getValues<Attribute>())
3967 insertedValues.push_back(convertNumericAttr(value, destEltType));
3968 } else {
3969 insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
3970 }
3971
3972 auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3973 copy(insertedValues, allValues.begin() + insertBeginPosition);
3974 auto newAttr = DenseElementsAttr::get(destTy, allValues);
3975
3976 return newAttr;
3977}
3978
3979/// Folder to replace the `dest` operand of the insert op with the root dest of
3980/// the insert op use chain.
3981static Value foldInsertUseChain(InsertOp insertOp) {
3982 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3983 if (!destInsert)
3984 return {};
3985
3986 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3987 return {};
3988
3989 insertOp.setOperand(1, destInsert.getDest());
3990 return insertOp.getResult();
3991}
3992
3993void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3994 MLIRContext *context) {
3995 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3996 InsertChainFullyInitialized>(context);
3997}
3998
3999OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
4000 // Do not create constants with more than `vectorSizeFoldThreashold` elements,
4001 // unless the source vector constant has a single use.
4002 constexpr int64_t vectorSizeFoldThreshold = 256;
4003 // Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
4004 // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
4005 // (type mismatch).
4006 if (getNumIndices() == 0 && getValueToStoreType() == getType())
4007 return getValueToStore();
4008 // Fold `arith.constant` indices into the `vector.insert` operation.
4009 // Do not stop here as this fold may enable subsequent folds that require
4010 // constant indices.
4011 SmallVector<Value> operands = {getValueToStore(), getDest()};
4012 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
4013
4014 if (auto res = foldInsertUseChain(*this))
4015 return res;
4016 if (auto res = foldPoisonIndexInsertExtractOp(
4017 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
4018 return res;
4019 if (auto res = foldDenseElementsAttrDestInsertOp(
4020 *this, adaptor.getValueToStore(), adaptor.getDest(),
4021 vectorSizeFoldThreshold)) {
4022 return res;
4023 }
4024
4025 return inplaceFolded;
4026}
4027
4028//===----------------------------------------------------------------------===//
4029// InsertStridedSliceOp
4030//===----------------------------------------------------------------------===//
4031
4032void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4033 Value source, Value dest,
4034 ArrayRef<int64_t> offsets,
4035 ArrayRef<int64_t> strides) {
4036 result.addOperands({source, dest});
4037 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4038 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4039 result.addTypes(dest.getType());
4040 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
4041 offsetsAttr);
4042 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
4043 stridesAttr);
4044}
4045
4046// TODO: Should be moved to Tablegen ConfinedAttr attributes.
4047template <typename OpType>
4048static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
4049 ArrayAttr arrayAttr,
4051 StringRef attrName) {
4052 if (arrayAttr.size() > shape.size())
4053 return op.emitOpError("expected ")
4054 << attrName << " attribute of rank no greater than vector rank";
4055 return success();
4056}
4057
4058// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
4059// interval. If `halfOpen` is true then the admissible interval is [min, max).
4060// Otherwise, the admissible interval is [min, max].
4061template <typename OpType>
4062static LogicalResult
4064 int64_t max, StringRef attrName,
4065 bool halfOpen = true) {
4066 for (auto attr : arrayAttr) {
4067 auto val = llvm::cast<IntegerAttr>(attr).getInt();
4068 auto upper = max;
4069 if (!halfOpen)
4070 upper += 1;
4071 if (val < min || val >= upper)
4072 return op.emitOpError("expected ") << attrName << " to be confined to ["
4073 << min << ", " << upper << ")";
4074 }
4075 return success();
4076}
4077
4078// Returns true if all integers in `arrayAttr` are in the half-open [min, max}
4079// interval. If `halfOpen` is true then the admissible interval is [min, max).
4080// Otherwise, the admissible interval is [min, max].
4081template <typename OpType>
4082static LogicalResult
4084 ArrayRef<int64_t> shape, StringRef attrName,
4085 bool halfOpen = true, int64_t min = 0) {
4086 for (auto [index, attrDimPair] :
4087 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
4088 int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
4089 int64_t max = std::get<1>(attrDimPair);
4090 if (!halfOpen)
4091 max += 1;
4092 if (val < min || val >= max)
4093 return op.emitOpError("expected ")
4094 << attrName << " dimension " << index << " to be confined to ["
4095 << min << ", " << max << ")";
4096 }
4097 return success();
4098}
4099
4100// Returns true if, for all indices i = 0..shape.size()-1, val is in the
4101// [min, max} interval:
4102// val = `arrayAttr1[i]` + `arrayAttr2[i]`,
4103// If `halfOpen` is true then the admissible interval is [min, max). Otherwise,
4104// the admissible interval is [min, max].
4105template <typename OpType>
4107 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
4108 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
4109 bool halfOpen = true, int64_t min = 1) {
4110 assert(arrayAttr1.size() <= shape.size());
4111 assert(arrayAttr2.size() <= shape.size());
4112 for (auto [index, it] :
4113 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
4114 auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
4115 auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
4116 int64_t max = std::get<2>(it);
4117 if (!halfOpen)
4118 max += 1;
4119 if (val1 + val2 < 0 || val1 + val2 >= max)
4120 return op.emitOpError("expected sum(")
4121 << attrName1 << ", " << attrName2 << ") dimension " << index
4122 << " to be confined to [" << min << ", " << max << ")";
4123 }
4124 return success();
4125}
4126
4128 MLIRContext *context) {
4129 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
4130 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
4131 });
4132 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
4133}
4134
4135LogicalResult InsertStridedSliceOp::verify() {
4136 auto sourceVectorType = getSourceVectorType();
4137 auto destVectorType = getDestVectorType();
4138 auto offsets = getOffsetsAttr();
4139 auto strides = getStridesAttr();
4140 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
4141 return emitOpError(
4142 "expected offsets of same size as destination vector rank");
4143 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
4144 return emitOpError("expected strides of same size as source vector rank");
4145 if (sourceVectorType.getRank() > destVectorType.getRank())
4146 return emitOpError(
4147 "expected source rank to be no greater than destination rank");
4148
4149 auto sourceShape = sourceVectorType.getShape();
4150 auto destShape = destVectorType.getShape();
4151 SmallVector<int64_t, 4> sourceShapeAsDestShape(
4152 destShape.size() - sourceShape.size(), 0);
4153 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
4154 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
4155 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
4156 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
4157 offName)) ||
4158 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4159 /*max=*/1, stridesName,
4160 /*halfOpen=*/false)) ||
4162 *this, offsets,
4163 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
4164 offName, "source vector shape",
4165 /*halfOpen=*/false, /*min=*/1)))
4166 return failure();
4167
4168 unsigned rankDiff = destShape.size() - sourceShape.size();
4169 for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
4170 if (sourceVectorType.getScalableDims()[idx] !=
4171 destVectorType.getScalableDims()[idx + rankDiff]) {
4172 return emitOpError("mismatching scalable flags (at source vector idx=")
4173 << idx << ")";
4174 }
4175 if (sourceVectorType.getScalableDims()[idx]) {
4176 auto sourceSize = sourceShape[idx];
4177 auto destSize = destShape[idx + rankDiff];
4178 if (sourceSize != destSize) {
4179 return emitOpError("expected size at idx=")
4180 << idx
4181 << (" to match the corresponding base size from the input "
4182 "vector (")
4183 << sourceSize << (" vs ") << destSize << (")");
4184 }
4185 }
4186 }
4187
4188 return success();
4189}
4190
4191namespace {
4192/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
4193class FoldInsertStridedSliceSplat final
4194 : public OpRewritePattern<InsertStridedSliceOp> {
4195public:
4196 using Base::Base;
4197
4198 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4199 PatternRewriter &rewriter) const override {
4200
4201 auto dst = insertStridedSliceOp.getDest();
4202 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4203 if (!splat || getScalarSplatSource(dst) != splat)
4204 return failure();
4205
4206 rewriter.replaceOp(insertStridedSliceOp, dst);
4207 return success();
4208 }
4209};
4210
4211/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
4212/// to dst.
4213class FoldInsertStridedSliceOfExtract final
4214 : public OpRewritePattern<InsertStridedSliceOp> {
4215public:
4216 using Base::Base;
4217
4218 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4219 PatternRewriter &rewriter) const override {
4220 auto extractStridedSliceOp =
4221 insertStridedSliceOp.getValueToStore()
4222 .getDefiningOp<vector::ExtractStridedSliceOp>();
4223
4224 if (!extractStridedSliceOp)
4225 return failure();
4226
4227 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4228 return failure();
4229
4230 // Check if have the same strides and offsets.
4231 if (extractStridedSliceOp.getStrides() !=
4232 insertStridedSliceOp.getStrides() ||
4233 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4234 return failure();
4235
4236 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4237 return success();
4238 }
4239};
4240
4241// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
4242// ConstantOp.
4243class InsertStridedSliceConstantFolder final
4244 : public OpRewritePattern<InsertStridedSliceOp> {
4245public:
4246 using Base::Base;
4247
4248 // Do not create constants with more than `vectorSizeFoldThreashold` elements,
4249 // unless the source vector constant has a single use.
4250 static constexpr int64_t vectorSizeFoldThreshold = 256;
4251
4252 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4253 PatternRewriter &rewriter) const override {
4254 // Return if 'InsertOp' operand is not defined by a compatible vector
4255 // ConstantOp.
4256 TypedValue<VectorType> destVector = op.getDest();
4257 Attribute vectorDestCst;
4258 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
4259 return failure();
4260
4261 VectorType destTy = destVector.getType();
4262 if (destTy.isScalable())
4263 return failure();
4264
4265 // Make sure we do not create too many large constants.
4266 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4267 !destVector.hasOneUse())
4268 return failure();
4269
4270 TypedValue<VectorType> sourceValue = op.getValueToStore();
4271 Attribute sourceCst;
4272 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
4273 return failure();
4274
4275 // TODO: Support poison.
4276 if (matchPattern(vectorDestCst, ub::m_Poison()) ||
4277 matchPattern(sourceCst, ub::m_Poison()))
4278 return failure();
4279
4280 // TODO: Handle non-unit strides when they become available.
4281 if (op.hasNonUnitStrides())
4282 return failure();
4283
4284 VectorType sliceVecTy = sourceValue.getType();
4285 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4286 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4287 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
4288 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
4289
4290 // Calcualte the destination element indices by enumerating all slice
4291 // positions within the destination and linearizing them. The enumeration
4292 // order is lexicographic which yields a sequence of monotonically
4293 // increasing linearized position indices.
4294 // Because the destination may have higher dimensionality then the slice,
4295 // we keep track of two overlapping sets of positions and offsets.
4296 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4297 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4298 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4299 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4300 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4301 MutableArrayRef<int64_t> currSlicePosition(
4302 currDestPosition.begin() + rankDifference, currDestPosition.end());
4303 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4304 offsets.end());
4305 do {
4306 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
4307 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
4308 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4309 "Invalid slice element");
4310 newValues[linearizedPosition] = *sliceValuesIt;
4311 ++sliceValuesIt;
4312 } while (succeeded(
4313 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
4314
4315 auto newAttr = DenseElementsAttr::get(destTy, newValues);
4316 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
4317 return success();
4318 }
4319};
4320
4321} // namespace
4322
4323void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4324 RewritePatternSet &results, MLIRContext *context) {
4325 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4326 InsertStridedSliceConstantFolder>(context);
4327}
4328
4329OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4330 if (getSourceVectorType() == getDestVectorType())
4331 return getValueToStore();
4332 return {};
4333}
4334
4335//===----------------------------------------------------------------------===//
4336// OuterProductOp
4337//===----------------------------------------------------------------------===//
4338
4339/// Build an op without mask, use the type of `acc` as the return type.
4340void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4341 Value lhs, Value rhs, Value acc) {
4342 result.addOperands({lhs, rhs, acc});
4343 result.addTypes(acc.getType());
4344}
4345
4346void OuterProductOp::print(OpAsmPrinter &p) {
4347 p << " " << getLhs() << ", " << getRhs();
4348 if (getAcc()) {
4349 p << ", " << getAcc();
4350 p.printOptionalAttrDict((*this)->getAttrs());
4351 }
4352 p << " : " << getLhs().getType() << ", " << getRhs().getType();
4353}
4354
4355ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4356 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4357 Type tLHS, tRHS;
4358 if (parser.parseOperandList(operandsInfo) ||
4359 parser.parseOptionalAttrDict(result.attributes) ||
4360 parser.parseColonType(tLHS) || parser.parseComma() ||
4361 parser.parseType(tRHS))
4362 return failure();
4363 if (operandsInfo.size() < 2)
4364 return parser.emitError(parser.getNameLoc(),
4365 "expected at least 2 operands");
4366 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4367 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4368 if (!vLHS)
4369 return parser.emitError(parser.getNameLoc(),
4370 "expected vector type for operand #1");
4371
4372 VectorType resType;
4373 if (vRHS) {
4374 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4375 vRHS.getScalableDims()[0]};
4376 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4377 vLHS.getElementType(), scalableDimsRes);
4378 } else {
4379 // Scalar RHS operand
4380 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4381 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4382 scalableDimsRes);
4383 }
4384
4385 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4386 result.attributes.append(
4387 OuterProductOp::getKindAttrName(result.name),
4388 CombiningKindAttr::get(result.getContext(),
4389 OuterProductOp::getDefaultKind()));
4390 }
4391
4392 return failure(
4393 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4394 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4395 (operandsInfo.size() > 2 &&
4396 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4397 parser.addTypeToList(resType, result.types));
4398}
4399
4400LogicalResult OuterProductOp::verify() {
4401 Type tRHS = getOperandTypeRHS();
4402 VectorType vLHS = getOperandVectorTypeLHS(),
4403 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4404 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4405
4406 if (vLHS.getRank() != 1)
4407 return emitOpError("expected 1-d vector for operand #1");
4408
4409 if (vRHS) {
4410 // Proper OUTER operation.
4411 if (vRHS.getRank() != 1)
4412 return emitOpError("expected 1-d vector for operand #2");
4413 if (vRES.getRank() != 2)
4414 return emitOpError("expected 2-d vector result");
4415 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4416 return emitOpError("expected #1 operand dim to match result dim #1");
4417 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4418 return emitOpError("expected #2 operand dim to match result dim #2");
4419 if (vLHS.isScalable() && !vRHS.isScalable()) {
4420 // This restriction reflects what's currently supported in terms of
4421 // scalable vectors. However, we could relax this if there's a use case.
4422 return emitOpError(
4423 "expected either both or only #2 operand dim to be scalable");
4424 }
4425 } else {
4426 // An AXPY operation.
4427 if (vRES.getRank() != 1)
4428 return emitOpError("expected 1-d vector result");
4429 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4430 return emitOpError("expected #1 operand dim to match result dim #1");
4431 }
4432
4433 if (vACC && vACC != vRES)
4434 return emitOpError("expected operand #3 of same type as result type");
4435
4436 if (!getKindAttr()) {
4437 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
4438 "'vector.kind<add>')");
4439 }
4440
4441 // Verify supported combining kind.
4442 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4443 return emitOpError("unsupported outerproduct type");
4444
4445 return success();
4446}
4447
4448// MaskableOpInterface methods.
4449
4450/// Returns the mask type expected by this operation. Mostly used for
4451/// verification purposes. It requires the operation to be vectorized."
4452Type OuterProductOp::getExpectedMaskType() {
4453 auto vecType = this->getResultVectorType();
4454 return VectorType::get(vecType.getShape(),
4455 IntegerType::get(vecType.getContext(), /*width=*/1),
4456 vecType.getScalableDims());
4457}
4458
4459//===----------------------------------------------------------------------===//
4460// ExtractStridedSliceOp
4461//===----------------------------------------------------------------------===//
4462
4463// Inference works as follows:
4464// 1. Add 'sizes' from prefix of dims in 'offsets'.
4465// 2. Add sizes from 'vectorType' for remaining dims.
4466// Scalable flags are inherited from 'vectorType'.
4467static Type inferStridedSliceOpResultType(VectorType vectorType,
4468 ArrayAttr offsets, ArrayAttr sizes,
4469 ArrayAttr strides) {
4470 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4472 shape.reserve(vectorType.getRank());
4473 unsigned idx = 0;
4474 for (unsigned e = offsets.size(); idx < e; ++idx)
4475 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4476 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4477 shape.push_back(vectorType.getShape()[idx]);
4478
4479 return VectorType::get(shape, vectorType.getElementType(),
4480 vectorType.getScalableDims());
4481}
4482
4483void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4484 Value source, ArrayRef<int64_t> offsets,
4485 ArrayRef<int64_t> sizes,
4486 ArrayRef<int64_t> strides) {
4487 result.addOperands(source);
4488 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4489 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4490 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4491 result.addTypes(
4492 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4493 offsetsAttr, sizesAttr, stridesAttr));
4494 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4495 offsetsAttr);
4496 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4497 sizesAttr);
4498 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4499 stridesAttr);
4500}
4501
4502LogicalResult ExtractStridedSliceOp::verify() {
4503 auto type = getSourceVectorType();
4504 auto offsets = getOffsetsAttr();
4505 auto sizes = getSizesAttr();
4506 auto strides = getStridesAttr();
4507 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4508 return emitOpError(
4509 "expected offsets, sizes and strides attributes of same size");
4510
4511 auto shape = type.getShape();
4512 auto offName = getOffsetsAttrName();
4513 auto sizesName = getSizesAttrName();
4514 auto stridesName = getStridesAttrName();
4515 if (failed(
4516 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4517 failed(
4518 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4519 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4520 stridesName)) ||
4521 failed(
4522 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4523 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4524 /*halfOpen=*/false,
4525 /*min=*/1)) ||
4526 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4527 /*max=*/1, stridesName,
4528 /*halfOpen=*/false)) ||
4529 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4530 shape, offName, sizesName,
4531 /*halfOpen=*/false)))
4532 return failure();
4533
4534 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4535 offsets, sizes, strides);
4536 if (getResult().getType() != resultType)
4537 return emitOpError("expected result type to be ") << resultType;
4538
4539 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4540 if (type.getScalableDims()[idx]) {
4541 auto inputDim = type.getShape()[idx];
4542 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4543 if (inputDim != inputSize)
4544 return emitOpError("expected size at idx=")
4545 << idx
4546 << (" to match the corresponding base size from the input "
4547 "vector (")
4548 << inputSize << (" vs ") << inputDim << (")");
4549 }
4550 }
4551
4552 return success();
4553}
4554
4555// When the source of ExtractStrided comes from a chain of InsertStrided ops try
4556// to use the source of the InsertStrided ops if we can detect that the
4557// extracted vector is a subset of one of the vector inserted.
4558static LogicalResult
4559foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4560 // Helper to extract integer out of ArrayAttr.
4561 auto getElement = [](ArrayAttr array, int idx) {
4562 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4563 };
4564 ArrayAttr extractOffsets = op.getOffsets();
4565 ArrayAttr extractStrides = op.getStrides();
4566 ArrayAttr extractSizes = op.getSizes();
4567 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4568 while (insertOp) {
4569 if (op.getSourceVectorType().getRank() !=
4570 insertOp.getSourceVectorType().getRank())
4571 return failure();
4572 ArrayAttr insertOffsets = insertOp.getOffsets();
4573 ArrayAttr insertStrides = insertOp.getStrides();
4574 // If the rank of extract is greater than the rank of insert, we are likely
4575 // extracting a partial chunk of the vector inserted.
4576 if (extractOffsets.size() > insertOffsets.size())
4577 return failure();
4578 bool patialoverlap = false;
4579 bool disjoint = false;
4580 SmallVector<int64_t, 4> offsetDiffs;
4581 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4582 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4583 return failure();
4584 int64_t start = getElement(insertOffsets, dim);
4585 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4586 int64_t offset = getElement(extractOffsets, dim);
4587 int64_t size = getElement(extractSizes, dim);
4588 // Check if the start of the extract offset is in the interval inserted.
4589 if (start <= offset && offset < end) {
4590 // If the extract interval overlaps but is not fully included we may
4591 // have a partial overlap that will prevent any folding.
4592 if (offset + size > end)
4593 patialoverlap = true;
4594 offsetDiffs.push_back(offset - start);
4595 continue;
4596 }
4597 disjoint = true;
4598 break;
4599 }
4600 // The extract element chunk is a subset of the insert element.
4601 if (!disjoint && !patialoverlap) {
4602 op.setOperand(insertOp.getValueToStore());
4603 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4604 OpBuilder b(op.getContext());
4605 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4606 return success();
4607 }
4608 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4609 // in the insert chain.
4610 if (disjoint)
4611 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4612 else {
4613 // The extracted vector partially overlap the inserted vector, we cannot
4614 // fold.
4615 return failure();
4616 }
4617 }
4618 return failure();
4619}
4620
4621// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4622static OpFoldResult
4624 Attribute foldInput) {
4625
4626 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4627 if (!dense)
4628 return {};
4629
4630 // TODO: Handle non-unit strides when they become available.
4631 if (op.hasNonUnitStrides())
4632 return {};
4633
4634 VectorType sourceVecTy = op.getSourceVectorType();
4635 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4636 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4637
4638 VectorType sliceVecTy = op.getType();
4639 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4640 int64_t rank = sliceVecTy.getRank();
4641
4642 // Expand offsets and sizes to match the vector rank.
4643 SmallVector<int64_t, 4> offsets(rank, 0);
4644 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4645
4646 SmallVector<int64_t, 4> sizes(sourceShape);
4647 copy(getI64SubArray(op.getSizes()), sizes.begin());
4648
4649 // Calculate the slice elements by enumerating all slice positions and
4650 // linearizing them. The enumeration order is lexicographic which yields a
4651 // sequence of monotonically increasing linearized position indices.
4652 const auto denseValuesBegin = dense.value_begin<Attribute>();
4653 SmallVector<Attribute> sliceValues;
4654 sliceValues.reserve(sliceVecTy.getNumElements());
4655 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4656 do {
4657 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4658 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4659 "Invalid index");
4660 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4661 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4662
4663 assert(static_cast<int64_t>(sliceValues.size()) ==
4664 sliceVecTy.getNumElements() &&
4665 "Invalid number of slice elements");
4666 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4667}
4668
4669OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4670 if (getSourceVectorType() == getResult().getType())
4671 return getSource();
4672 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4673 return getResult();
4674
4675 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4676 if (auto splat =
4677 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4678 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4679
4680 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4681 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4682}
4683
4684void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4685 populateFromInt64AttrArray(getOffsets(), results);
4686}
4687
4688namespace {
4689
4690// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
4691//
4692// Example:
4693//
4694// %0 = vector.extract_strided_slice %arg0
4695// {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
4696// : vector<4x8x16xf32> to vector<3x4x16xf32>
4697// %1 = vector.extract_strided_slice %0
4698// {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
4699// : vector<3x4x16xf32> to vector<2x2x16xf32>
4700//
4701// to
4702//
4703// %1 = vector.extract_strided_slice %arg0
4704// {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
4705// : vector<4x8x16xf32> to vector<2x2x16xf32>
4706class StridedSliceFolder final
4707 : public OpRewritePattern<ExtractStridedSliceOp> {
4708public:
4709 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4710
4711 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4712 PatternRewriter &rewriter) const override {
4713 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4714 if (!firstOp)
4715 return failure();
4716
4717 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4718 return failure();
4719
4720 SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
4721 SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
4722 SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
4723 SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());
4724
4725 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4726 SmallVector<int64_t> combinedOffsets(newRank, 0);
4727 SmallVector<int64_t> combinedSizes(newRank);
4728 ArrayRef<int64_t> firstSourceShape =
4729 firstOp.getSourceVectorType().getShape();
4730 for (unsigned i = 0; i < newRank; ++i) {
4731 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4732 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4733 combinedOffsets[i] = off1 + off2;
4734
4735 if (i < secondSizes.size()) {
4736 combinedSizes[i] = secondSizes[i];
4737 } else if (i < firstSizes.size()) {
4738 combinedSizes[i] = firstSizes[i];
4739 } else {
4740 combinedSizes[i] = firstSourceShape[i];
4741 }
4742 }
4743
4744 SmallVector<int64_t> combinedStrides(newRank, 1);
4745 rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
4746 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4747 combinedStrides);
4748 return success();
4749 }
4750};
4751
4752// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4753// CreateMaskOp.
4754//
4755// Example:
4756//
4757// %mask = vector.create_mask %ub : vector<16xi1>
4758// %slice = vector.extract_strided_slice [%offset] [8] [1]
4759//
4760// to
4761//
4762// %new_ub = arith.subi %ub, %offset
4763// %mask = vector.create_mask %new_ub : vector<8xi1>
4764class StridedSliceCreateMaskFolder final
4765 : public OpRewritePattern<ExtractStridedSliceOp> {
4766 using Base::Base;
4767
4768public:
4769 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4770 PatternRewriter &rewriter) const override {
4771 Location loc = extractStridedSliceOp.getLoc();
4772 // Return if 'extractStridedSliceOp' operand is not defined by a
4773 // CreateMaskOp.
4774 auto createMaskOp =
4775 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4776 if (!createMaskOp)
4777 return failure();
4778 // Return if 'extractStridedSliceOp' has non-unit strides.
4779 if (extractStridedSliceOp.hasNonUnitStrides())
4780 return failure();
4781 // Gather constant mask dimension sizes.
4782 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4783 // Gather strided slice offsets and sizes.
4784 SmallVector<int64_t> sliceOffsets;
4785 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4786 sliceOffsets);
4787 SmallVector<int64_t> sliceSizes;
4788 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4789
4790 // Compute slice of vector mask region.
4791 SmallVector<Value> sliceMaskDimSizes;
4792 sliceMaskDimSizes.reserve(maskDimSizes.size());
4793 // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4794 // only iterate on the leading dim sizes. The tail accounts for the
4795 // remaining dim sizes.
4796 for (auto [maskDimSize, sliceOffset, sliceSize] :
4797 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4798 // No need to clamp on min/max values, because create_mask has clamping
4799 // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4800 // greater than the vector dim size.
4801 IntegerAttr offsetAttr =
4802 rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4803 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4804 Value sliceMaskDimSize =
4805 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4806 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4807 }
4808 // Add unchanged dimensions.
4809 llvm::append_range(
4810 sliceMaskDimSizes,
4811 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4812 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4813 // region.
4814 rewriter.replaceOpWithNewOp<CreateMaskOp>(
4815 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4816 sliceMaskDimSizes);
4817 return success();
4818 }
4819};
4820
4821// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4822// ConstantMaskOp.
4823class StridedSliceConstantMaskFolder final
4824 : public OpRewritePattern<ExtractStridedSliceOp> {
4825public:
4826 using Base::Base;
4827
4828 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4829 PatternRewriter &rewriter) const override {
4830 // Return if 'extractStridedSliceOp' operand is not defined by a
4831 // ConstantMaskOp.
4832 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4833 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4834 if (!constantMaskOp)
4835 return failure();
4836 // Return if 'extractStridedSliceOp' has non-unit strides.
4837 if (extractStridedSliceOp.hasNonUnitStrides())
4838 return failure();
4839 // Gather constant mask dimension sizes.
4840 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4841 // Gather strided slice offsets and sizes.
4842 SmallVector<int64_t> sliceOffsets;
4843 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4844 sliceOffsets);
4845 SmallVector<int64_t> sliceSizes;
4846 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4847
4848 // Compute slice of vector mask region.
4849 SmallVector<int64_t> sliceMaskDimSizes;
4850 sliceMaskDimSizes.reserve(maskDimSizes.size());
4851 for (auto [maskDimSize, sliceOffset, sliceSize] :
4852 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4853 int64_t sliceMaskDimSize = std::max(
4854 static_cast<int64_t>(0),
4855 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4856 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4857 }
4858 // Add unchanged dimensions.
4859 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4860 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4861 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4862 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4863 // region is a conjunction of mask dim intervals).
4864 if (llvm::is_contained(sliceMaskDimSizes, 0))
4865 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4866
4867 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4868 // region.
4869 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4870 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4871 sliceMaskDimSizes);
4872 return success();
4873 }
4874};
4875
4876// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4877// BroadcastOp(ExtractStrideSliceOp).
4878class StridedSliceBroadcast final
4879 : public OpRewritePattern<ExtractStridedSliceOp> {
4880public:
4881 using Base::Base;
4882
4883 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4884 PatternRewriter &rewriter) const override {
4885 auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4886 if (!broadcast)
4887 return failure();
4888 auto srcVecType =
4889 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4890 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4891 auto dstVecType = llvm::cast<VectorType>(op.getType());
4892 unsigned dstRank = dstVecType.getRank();
4893 unsigned rankDiff = dstRank - srcRank;
4894 // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4895 // (n -> m with n > m). If they are originally both broadcasted *and*
4896 // sliced, this can be simplified to just broadcasting.
4897 bool needsSlice = false;
4898 for (unsigned i = 0; i < srcRank; i++) {
4899 if (srcVecType.getDimSize(i) != 1 &&
4900 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4901 needsSlice = true;
4902 break;
4903 }
4904 }
4905 Value source = broadcast.getSource();
4906 if (needsSlice) {
4907 SmallVector<int64_t> offsets =
4908 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4909 SmallVector<int64_t> sizes =
4910 getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4911 for (unsigned i = 0; i < srcRank; i++) {
4912 if (srcVecType.getDimSize(i) == 1) {
4913 // In case this dimension was broadcasted *and* sliced, the offset
4914 // and size need to be updated now that there is no broadcast before
4915 // the slice.
4916 offsets[i] = 0;
4917 sizes[i] = 1;
4918 }
4919 }
4920 source = ExtractStridedSliceOp::create(
4921 rewriter, op->getLoc(), source, offsets, sizes,
4922 getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4923 }
4924 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4925 return success();
4926 }
4927};
4928
4929/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
4930class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4931public:
4932 using Base::Base;
4933
4934 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4935 PatternRewriter &rewriter) const override {
4936
4937 Value splat = getScalarSplatSource(op.getSource());
4938 if (!splat)
4939 return failure();
4940 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4941 return success();
4942 }
4943};
4944
4945/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4946/// slice is contiguous, into extract and shape_cast.
4947///
4948/// Example:
4949/// Before:
4950/// %1 = vector.extract_strided_slice %arg0 {
4951/// offsets = [0, 0, 0, 0, 0],
4952/// sizes = [1, 1, 1, 1, 8],
4953/// strides = [1, 1, 1, 1, 1]
4954/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4955/// After:
4956/// %0 = vector.extract %arg0[0, 0, 0, 0]
4957/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4958/// %1 = vector.shape_cast %0
4959/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4960///
4961class ContiguousExtractStridedSliceToExtract final
4962 : public OpRewritePattern<ExtractStridedSliceOp> {
4963public:
4964 using Base::Base;
4965
4966 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4967 PatternRewriter &rewriter) const override {
4968 if (op.hasNonUnitStrides())
4969 return failure();
4970 Value source = op.getOperand();
4971 auto sourceType = cast<VectorType>(source.getType());
4972 if (sourceType.isScalable() || sourceType.getRank() == 0)
4973 return failure();
4974
4975 // Compute the number of offsets to pass to ExtractOp::build. That is the
4976 // difference between the source rank and the desired slice rank. We walk
4977 // the dimensions from innermost out, and stop when the next slice dimension
4978 // is not full-size.
4979 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4980 int numOffsets;
4981 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4982 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4983 break;
4984 }
4985
4986 // If the created extract op would have no offsets, then this whole
4987 // extract_strided_slice is the identity and should have been handled by
4988 // other canonicalizations.
4989 if (numOffsets == 0)
4990 return failure();
4991
4992 // If not even the inner-most dimension is full-size, this op can't be
4993 // rewritten as an ExtractOp.
4994 if (numOffsets == sourceType.getRank() &&
4995 static_cast<int>(sizes.size()) == sourceType.getRank())
4996 return failure();
4997
4998 // The outer dimensions must have unit size.
4999 for (int i = 0; i < numOffsets; ++i) {
5000 if (sizes[i] != 1)
5001 return failure();
5002 }
5003
5004 // Avoid generating slices that have leading unit dimensions. The shape_cast
5005 // op that we create below would take bad generic fallback patterns
5006 // (ShapeCastOpRewritePattern).
5007 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
5008 sizes[numOffsets] == 1) {
5009 ++numOffsets;
5010 }
5011
5012 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
5013 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
5014 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
5015 extractOffsets);
5016 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
5017 return success();
5018 }
5019};
5020
5021} // namespace
5022
5023void ExtractStridedSliceOp::getCanonicalizationPatterns(
5024 RewritePatternSet &results, MLIRContext *context) {
5025 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
5026 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
5027 results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
5028 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
5029 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
5030 context);
5031}
5032
5033//===----------------------------------------------------------------------===//
5034// TransferReadOp
5035//===----------------------------------------------------------------------===//
5036
5037/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
5038void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5039 VectorType vectorType, Value source,
5040 ValueRange indices, std::optional<Value> padding,
5041 AffineMapAttr permutationMapAttr,
5042 /*optional*/ ArrayAttr inBoundsAttr) {
5043
5044 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5045 if (!padding)
5046 padding = ub::PoisonOp::create(builder, result.location, elemType);
5047 build(builder, result, vectorType, source, indices, permutationMapAttr,
5048 *padding, /*mask=*/Value(), inBoundsAttr);
5049}
5050
5051/// 2. Builder that sets padding to zero an empty mask (variant without attrs).
5052void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5053 VectorType vectorType, Value source,
5054 ValueRange indices, std::optional<Value> padding,
5055 AffineMap permutationMap,
5056 std::optional<ArrayRef<bool>> inBounds) {
5057 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5058 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5059 ? builder.getBoolArrayAttr(inBounds.value())
5060 : builder.getBoolArrayAttr(
5061 SmallVector<bool>(vectorType.getRank(), false));
5062 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5063 if (!padding)
5064 padding = ub::PoisonOp::create(builder, result.location, elemType);
5065 build(builder, result, vectorType, source, indices, *padding,
5066 permutationMapAttr, inBoundsAttr);
5067}
5068
5069/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
5070void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5071 VectorType vectorType, Value source,
5072 ValueRange indices, std::optional<Value> padding,
5073 std::optional<ArrayRef<bool>> inBounds) {
5074 AffineMap permutationMap = getTransferMinorIdentityMap(
5075 llvm::cast<ShapedType>(source.getType()), vectorType);
5076 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5077 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5078 ? builder.getBoolArrayAttr(inBounds.value())
5079 : builder.getBoolArrayAttr(
5080 SmallVector<bool>(vectorType.getRank(), false));
5081 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5082 if (!padding)
5083 padding = ub::PoisonOp::create(builder, result.location, elemType);
5084 build(builder, result, vectorType, source, indices, permutationMapAttr,
5085 *padding,
5086 /*mask=*/Value(), inBoundsAttr);
5087}
5088
5089template <typename EmitFun>
5090static LogicalResult verifyPermutationMap(AffineMap permutationMap,
5091 EmitFun emitOpError) {
5092 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
5093 for (auto expr : permutationMap.getResults()) {
5094 auto dim = dyn_cast<AffineDimExpr>(expr);
5095 auto zero = dyn_cast<AffineConstantExpr>(expr);
5096 if (zero) {
5097 if (zero.getValue() != 0) {
5098 return emitOpError(
5099 "requires a projected permutation_map (at most one dim or the zero "
5100 "constant can appear in each result)");
5101 }
5102 continue;
5103 }
5104 if (!dim) {
5105 return emitOpError("requires a projected permutation_map (at most one "
5106 "dim or the zero constant can appear in each result)");
5107 }
5108 if (seen[dim.getPosition()]) {
5109 return emitOpError(
5110 "requires a permutation_map that is a permutation (found one dim "
5111 "used more than once)");
5112 }
5113 seen[dim.getPosition()] = true;
5114 }
5115 return success();
5116}
5117
5118static LogicalResult
5119verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
5120 VectorType vectorType, VectorType maskType,
5121 VectorType inferredMaskType, AffineMap permutationMap,
5122 ArrayAttr inBounds) {
5123 if (op->hasAttr("masked")) {
5124 return op->emitOpError("masked attribute has been removed. "
5125 "Use in_bounds instead.");
5126 }
5127
5128 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
5129 return op->emitOpError(
5130 "requires source to be a memref or ranked tensor type");
5131
5132 auto elementType = shapedType.getElementType();
5133 DataLayout dataLayout = DataLayout::closest(op);
5134 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
5135 // Memref or tensor has vector element type.
5136 unsigned sourceVecSize =
5137 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
5138 vectorElementType.getShape().back();
5139 unsigned resultVecSize =
5140 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
5141 vectorType.getShape().back();
5142 if (resultVecSize % sourceVecSize != 0)
5143 return op->emitOpError(
5144 "requires the bitwidth of the minor 1-D vector to be an integral "
5145 "multiple of the bitwidth of the minor 1-D vector of the source");
5146
5147 unsigned sourceVecEltRank = vectorElementType.getRank();
5148 unsigned resultVecRank = vectorType.getRank();
5149 if (sourceVecEltRank > resultVecRank)
5150 return op->emitOpError(
5151 "requires source vector element and vector result ranks to match.");
5152 unsigned rankOffset = resultVecRank - sourceVecEltRank;
5153 // Check that permutation map results match 'rankOffset' of vector type.
5154 if (permutationMap.getNumResults() != rankOffset)
5155 return op->emitOpError("requires a permutation_map with result dims of "
5156 "the same rank as the vector type");
5157
5158 if (maskType)
5159 return op->emitOpError("does not support masks with vector element type");
5160 } else {
5161 // Memref or tensor has scalar element type.
5162 unsigned minorSize =
5163 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
5164 unsigned resultVecSize =
5165 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
5166 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
5167 return op->emitOpError(
5168 "requires the bitwidth of the minor 1-D vector to be an integral "
5169 "multiple of the bitwidth of the source element type");
5170
5171 // Check that permutation map results match rank of vector type.
5172 if (permutationMap.getNumResults() != vectorType.getRank())
5173 return op->emitOpError("requires a permutation_map with result dims of "
5174 "the same rank as the vector type");
5175 }
5176
5177 if (permutationMap.getNumSymbols() != 0)
5178 return op->emitOpError("requires permutation_map without symbols");
5179
5180 if (permutationMap.getNumInputs() != shapedType.getRank())
5181 return op->emitOpError("requires a permutation_map with input dims of the "
5182 "same rank as the source type");
5183
5184 if (maskType && maskType != inferredMaskType)
5185 return op->emitOpError("inferred mask type (")
5186 << inferredMaskType << ") and mask operand type (" << maskType
5187 << ") don't match";
5188
5189 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
5190 return op->emitOpError("expects the in_bounds attr of same rank "
5191 "as permutation_map results: ")
5192 << AffineMapAttr::get(permutationMap)
5193 << " vs inBounds of size: " << inBounds.size();
5194
5195 return success();
5196}
5197
5198static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
5199 SmallVector<StringRef, 3> elidedAttrs;
5200 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5201 if (op.getPermutationMap().isMinorIdentity())
5202 elidedAttrs.push_back(op.getPermutationMapAttrName());
5203 // Elide in_bounds attribute if all dims are out-of-bounds.
5204 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
5205 elidedAttrs.push_back(op.getInBoundsAttrName());
5206 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5207}
5208
5209void TransferReadOp::print(OpAsmPrinter &p) {
5210 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
5211 if (getMask())
5212 p << ", " << getMask();
5213 printTransferAttrs(p, *this);
5214 p << " : " << getShapedType() << ", " << getVectorType();
5215}
5216
5217VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
5218 AffineMap permMap) {
5219 auto i1Type = IntegerType::get(permMap.getContext(), 1);
5220 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
5221 assert(invPermMap && "Inversed permutation map couldn't be computed");
5222 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
5223
5224 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
5225 // 0-D mask into a single-element 1-D mask.
5226 if (maskShape.empty())
5227 maskShape.push_back(1);
5228
5229 SmallVector<bool> scalableDims =
5230 applyPermutationMap(invPermMap, vecType.getScalableDims());
5231
5232 return VectorType::get(maskShape, i1Type, scalableDims);
5233}
5234
/// Custom parser for transfer_read:
///   `%base[%indices], %padding[, %mask] {attrs} : shaped_type, vector_type`
/// A missing `permutation_map` defaults to the minor identity map (only
/// allowed when the source rank is at least the effective vector rank). A
/// missing `in_bounds` defaults to all-false. The mask type is not spelled in
/// the syntax; it is inferred from the vector type and the permutation map.
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  // NOTE(review): source lines 5238-5242 are missing from this listing. They
  // presumably declare sourceInfo, indexInfo, paddingInfo, types and maskInfo,
  // which are used below. Restore them from upstream.
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      // NOTE(review): source line 5245 is missing from this listing. It
      // presumably parses the bracketed index list into indexInfo.
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  // Optional trailing `, %mask` operand.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
  if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  // Fill in the default permutation map when none was written.
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    if (shapedType.getRank() <
        getEffectiveVectorRankForXferOp(shapedType, vectorType))
      return parser.emitError(typesLoc,
                              "expected a custom permutation_map when "
                              "rank(source) != rank(destination)");
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    // NOTE(review): permMapAttr comes from user input; llvm::cast asserts if
    // it is not an AffineMapAttr. Consider dyn_cast plus a parser error.
    permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
  }
  // Fill in the default (all-false) in_bounds when none was written.
  auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
  Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
  if (!inBoundsAttr) {
    result.addAttribute(inBoundsAttrName,
                        builder.getBoolArrayAttr(
                            SmallVector<bool>(permMap.getNumResults(), false)));
  }
  // The padding operand has the source element type.
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    if (vectorType.getRank() != permMap.getNumResults()) {
      return parser.emitError(typesLoc,
                              "expected the same rank for the vector and the "
                              "results of the permutation map");
    }
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferOpMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  // Segments: base (1), indices (N), padding (1), mask (0 or 1).
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
5312
5313LogicalResult TransferReadOp::verify() {
5314 // Consistency of elemental types in source and vector.
5315 ShapedType shapedType = getShapedType();
5316 VectorType vectorType = getVectorType();
5317 VectorType maskType = getMaskType();
5318 auto paddingType = getPadding().getType();
5319 auto permutationMap = getPermutationMap();
5320 VectorType inferredMaskType =
5321 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5322 : VectorType();
5323 auto sourceElementType = shapedType.getElementType();
5324
5325 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
5326 return emitOpError("requires ") << shapedType.getRank() << " indices";
5327
5328 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5329 shapedType, vectorType, maskType,
5330 inferredMaskType, permutationMap, getInBounds())))
5331 return failure();
5332
5333 if (auto sourceVectorElementType =
5334 llvm::dyn_cast<VectorType>(sourceElementType)) {
5335 // Source has vector element type.
5336 // Check that 'sourceVectorElementType' and 'paddingType' types match.
5337 if (sourceVectorElementType != paddingType)
5338 return emitOpError(
5339 "requires source element type and padding type to match.");
5340
5341 } else {
5342 // Check that 'paddingType' is valid to store in a vector type.
5343 if (!VectorType::isValidElementType(paddingType))
5344 return emitOpError("requires valid padding vector elemental type");
5345
5346 // Check that padding type and vector element types match.
5347 if (paddingType != sourceElementType)
5348 return emitOpError(
5349 "requires formal padding and source of the same elemental type");
5350 }
5351
5352 return verifyPermutationMap(permutationMap,
5353 [&](Twine t) { return emitOpError(t); });
5354}
5355
5356// MaskableOpInterface methods.
5357
5358/// Returns the mask type expected by this operation. Mostly used for
5359/// verification purposes. It requires the operation to be vectorized."
5360Type TransferReadOp::getExpectedMaskType() {
5361 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5362}
5363
5364//===----------------------------------------------------------------------===//
5365// TransferReadOp: VectorTransferOpInterface methods.
5366//===----------------------------------------------------------------------===//
5367VectorType TransferReadOp::getVectorType() {
5368 return cast<VectorType>(getVector().getType());
5369}
5370
5371template <typename TransferOp>
5372static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
5373 // TODO: support more aggressive createOrFold on:
5374 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
5375 if (op.getShapedType().isDynamicDim(indicesIdx))
5376 return false;
5377 Value index = op.getIndices()[indicesIdx];
5378 std::optional<int64_t> cstOp = getConstantIntValue(index);
5379 if (!cstOp.has_value())
5380 return false;
5381
5382 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5383 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5384
5385 return cstOp.value() + vectorSize <= sourceSize;
5386}
5387
5388template <typename TransferOp>
5389static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
5390 // TODO: support 0-d corner case.
5391 // TODO: Be less conservative.
5392 if (op.getTransferRank() == 0)
5393 return failure();
5394 AffineMap permutationMap = op.getPermutationMap();
5395 bool changed = false;
5396 SmallVector<bool, 4> newInBounds;
5397 newInBounds.reserve(op.getTransferRank());
5398 // Idxs of non-bcast dims - used when analysing bcast dims.
5399 SmallVector<unsigned> nonBcastDims;
5400
5401 // 1. Process non-broadcast dims
5402 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5403 // 1.1. Already marked as in-bounds, nothing to see here.
5404 if (op.isDimInBounds(i)) {
5405 newInBounds.push_back(true);
5406 continue;
5407 }
5408 // 1.2. Currently out-of-bounds, check whether we can statically determine
5409 // it is inBounds.
5410 bool inBounds = false;
5411 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5412 if (dimExpr) {
5413 inBounds = isInBounds(op, /*resultIdx=*/i,
5414 /*indicesIdx=*/dimExpr.getPosition());
5415 nonBcastDims.push_back(i);
5416 }
5417
5418 newInBounds.push_back(inBounds);
5419 // We commit the pattern if it is "more inbounds".
5420 changed |= inBounds;
5421 }
5422
5423 // 2. Handle broadcast dims
5424 // If all non-broadcast dims are "in bounds", then all bcast dims should be
5425 // "in bounds" as well.
5426 bool allNonBcastDimsInBounds = llvm::all_of(
5427 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5428 if (allNonBcastDimsInBounds) {
5429 for (size_t idx : permutationMap.getBroadcastDims()) {
5430 changed |= !newInBounds[idx];
5431 newInBounds[idx] = true;
5432 }
5433 }
5434
5435 if (!changed)
5436 return failure();
5437 // OpBuilder is only used as a helper to build an I64ArrayAttr.
5438 OpBuilder b(op.getContext());
5439 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5440 return success();
5441}
5442
/// Removes the mask operand of `op` in place and returns success. Returns
/// failure if `op` has no mask.
/// NOTE(review): source line 5449, the condition guarding the second
/// `return failure()`, is missing from this listing. Judging by the name, it
/// presumably bails unless the mask is statically all-true (e.g. via
/// getMaskFormat). Confirm against upstream.
template <typename TransferOp>
static LogicalResult foldTransferFullMask(TransferOp op) {
  auto mask = op.getMask();
  if (!mask)
    return failure();

    return failure();

  // Drop the mask operand.
  op.getMaskMutable().clear();
  return success();
}
5455
5456/// When the vector type is `vector<1xT>`, the permutation map is irrelevant:
5457/// the single vector lane always has iteration offset 0, so the element is at
5458/// `indices` regardless of which source dimension the map points at. Replace
5459/// with the minor identity to unblock lowering to vector.load / vector.store.
5460template <typename TransferOp>
5461static LogicalResult foldSize1TransferPermutationMap(TransferOp op) {
5462 VectorType vecType = op.getVectorType();
5463 if (vecType.getRank() != 1 || vecType.getShape()[0] != 1 ||
5464 vecType.isScalable())
5465 return failure();
5466
5467 AffineMap map = op.getPermutationMap();
5468 if (map.isMinorIdentity())
5469 return failure();
5470
5471 int64_t srcRank = op.getShapedType().getRank();
5472 if (srcRank < 1)
5473 return failure();
5474
5475 AffineMap minorIdentity =
5476 AffineMap::getMinorIdentityMap(srcRank, 1, op.getContext());
5477 op.setPermutationMapAttr(AffineMapAttr::get(minorIdentity));
5478 return success();
5479}
5480
/// Read-after-write forwarding on ranked tensors. It walks the chain of
/// transfer_write ops that produce the read's base. It returns the written
/// vector of the first write for which `checkSameValueRAW(write, read)` holds.
///
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///   : vector<1x4xf32>, tensor<4x4xf32>
/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
///   : tensor<4x4xf32>, vector<1x4xf32>
/// ```
/// -> Folds into
/// ```
/// %v0
/// ```
/// Returns a null Value when the base is not a ranked tensor or when no write
/// can be forwarded.
static Value foldRAW(TransferReadOp readOp) {
  if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
    return {};
  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
  while (defWrite) {
    if (checkSameValueRAW(defWrite, readOp))
      return defWrite.getVector();
    // NOTE(review): source line 5498 is missing from this listing. It
    // presumably opens `if (!isDisjointTransferIndices(`, so the walk stops at
    // the first write that may overlap the read. Confirm against upstream.
        cast<VectorTransferOpInterface>(defWrite.getOperation()),
        cast<VectorTransferOpInterface>(readOp.getOperation())))
      break;
    // Continue with the write that produced this write's base.
    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
  }
  return {};
}
5506
5507OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5508 if (Value vec = foldRAW(*this))
5509 return vec;
5510 /// transfer_read(memrefcast) -> transfer_read
5511 if (succeeded(foldTransferInBoundsAttribute(*this)))
5512 return getResult();
5513 if (succeeded(foldTransferFullMask(*this)))
5514 return getResult();
5515 if (succeeded(foldSize1TransferPermutationMap(*this)))
5516 return getResult();
5517 if (succeeded(memref::foldMemRefCast(*this)))
5518 return getResult();
5519 if (succeeded(tensor::foldTensorCast(*this)))
5520 return getResult();
5521 return OpFoldResult();
5522}
5523
5524std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5525 return llvm::to_vector<4>(getVectorType().getShape());
5526}
5527
5528void TransferReadOp::getEffects(
5529 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5530 &effects) {
5531 if (llvm::isa<MemRefType>(getShapedType()))
5532 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5533 SideEffects::DefaultResource::get());
5534}
5535
/// Speculatability of transfer_read. Only the `hasPureTensorSemantics()` test
/// remains visible in this listing.
/// NOTE(review): source lines 5538-5539 are missing. They presumably return
/// Speculation::Speculatable for pure tensor semantics and
/// Speculation::NotSpeculatable otherwise. Restore them from upstream.
Speculation::Speculatability TransferReadOp::getSpeculatability() {
  if (hasPureTensorSemantics())
}
5541
/// Given a projected permutation, inverse an affine map, making the unused dims
/// 0 in the result.
///
/// The returned map has one dim per result of `map` and one result per input
/// of `map`. Result `pos` is `d<idx>` when result `idx` of `map` is `d<pos>`.
/// Inputs that `map` does not reference keep the initial element of
/// `results`.
/// NOTE(review): source line 5548, the fill value passed when constructing
/// `results`, is missing from this listing. Per the description above it is
/// presumably the affine constant 0. Confirm against upstream.
static AffineMap inverseWithUnusedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected a projected permutation map");
  SmallVector<AffineExpr> results(map.getNumInputs(),
  for (auto [idx, result] : llvm::enumerate(map.getResults())) {
    // We should only have dim exprs because this is a projected permutation.
    int64_t pos = cast<AffineDimExpr>(result).getPosition();
    results[pos] = getAffineDimExpr(idx, map.getContext());
  }
  return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
                        results, map.getContext());
}
5557
5558namespace {
5559/// Store to load forwarding for transfer operations with permuation maps.
5560/// Even if the permutation maps are different we can still propagate the store
5561/// into the load if the size of the dimensions read and written match. Then we
5562/// can replace the transfer_read + transfer_write by vector.broadcast and
5563/// vector.transpose.
5564/// Example:
5565/// ```
5566/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5567/// {in_bounds = [true, true],
5568/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5569/// vector<4x1xf32>, tensor<4x4x4xf32>
5570/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5571/// {in_bounds = [true, true, true, true],
5572/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5573/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5574/// ```
5575/// To:
5576/// ```
5577/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
5578/// %r = vector.transpose %0, [3, 0, 2, 1] :
5579/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5580/// ```
5581struct TransferReadAfterWriteToBroadcast
5582 : public OpRewritePattern<TransferReadOp> {
5583 using Base::Base;
5584
5585 LogicalResult matchAndRewrite(TransferReadOp readOp,
5586 PatternRewriter &rewriter) const override {
5587 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5588 if (!defWrite)
5589 return failure();
5590 // Bail if we need an alias analysis.
5591 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5592 return failure();
5593 // Bail in the masked case (too complex atm and needed to properly account
5594 // for padding).
5595 if (readOp.getMask() || defWrite.getMask())
5596 return failure();
5597 // If indices are not the same a shift may be required, bail.
5598 if (readOp.getIndices() != defWrite.getIndices())
5599 return failure();
5600 // Bail if we need a bounds analysis.
5601 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5602 return failure();
5603 // TODO: If the written transfer chunk is a superset of the read transfer
5604 // chunk we could do an extract_strided_slice.
5605 if (readOp.getTransferChunkAccessed() !=
5606 defWrite.getTransferChunkAccessed())
5607 return failure();
5608 // WriteMap: tensor -> w_vec
5609 // ReadMap: tensor -> r_vec
5610 //
5611 // inv(WriteMap): w_vec -> tensor
5612 // inv(WriteMap) o ReadMap: w_vec -> r_vec
5613 AffineMap readMap = readOp.getPermutationMap();
5614 AffineMap writeMap = defWrite.getPermutationMap();
5615 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5616 AffineMap composedMap = readMap.compose(invWriteMap);
5617 // If there are any unused dims in the composedMap, we have to drop some
5618 // unit dims from the written vector before we can do transpose(broadcast).
5619 // TODO: Support this case.
5620 if (getUnusedDimsBitVector(composedMap).any())
5621 return failure();
5622 // readVec = transpose(broadcast(writeVec))
5623 //
5624 // Build a transpose permutation for the above transpose operation.
5625 //
5626 // Treat the composed map as having extra leading dimensions which are
5627 // the broadcasted dimensions, and treat the zeros as these new broadcasted
5628 // dimensions.
5629 SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
5630 int64_t numBroadcastedDims = broadcastedDims.size();
5631 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5632 invPerm.resize(composedMap.getNumResults());
5633 for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
5634 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
5635 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5636 invPerm[effectiveDim] = idx;
5637 }
5638 }
5639 // Applying the inverse permutation on the readVecTy will give us the
5640 // broadcast result type.
5641 VectorType readVecTy = readOp.getVectorType();
5642 SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
5643 auto broadcastedVecTy =
5644 VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
5645 readVecTy.getElementType(),
5646 applyPermutation(readVecTy.getScalableDims(), invPerm));
5647 // Build the transpose(broadcast) transformation.
5648 Value vec = defWrite.getVector();
5649 Location loc = readOp.getLoc();
5650 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5651 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
5652 return success();
5653 }
5654};
5655} // namespace
5656
5657void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5658 MLIRContext *context) {
5659 results.add<TransferReadAfterWriteToBroadcast>(context);
5660}
5661
/// Cast bubbling hook for transfer_read. It returns failure unless the op has
/// pure buffer semantics.
/// NOTE(review): source line 5666 is missing from this listing. It begins the
/// returned expression whose final argument is `getResult()`. Restore it from
/// upstream.
FailureOr<std::optional<SmallVector<Value>>>
TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
                                                    getResult());
}
5669
5670//===----------------------------------------------------------------------===//
5671// TransferWriteOp
5672//===----------------------------------------------------------------------===//
5673
5674/// 1. Builder with type inference.
5675void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5676 Value vector, Value dest, ValueRange indices,
5677 AffineMapAttr permutationMapAttr,
5678 /*optional*/ Value mask,
5679 /*optional*/ ArrayAttr inBoundsAttr) {
5680 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5681 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5682 mask, inBoundsAttr);
5683}
5684
5685/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5686void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5687 Value vector, Value dest, ValueRange indices,
5688 AffineMapAttr permutationMapAttr,
5689 /*optional*/ ArrayAttr inBoundsAttr) {
5690 build(builder, result, vector, dest, indices, permutationMapAttr,
5691 /*mask=*/Value(), inBoundsAttr);
5692}
5693
5694/// 3. Builder with type inference that sets an empty mask (variant without
5695/// attrs)
5696void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5697 Value vector, Value dest, ValueRange indices,
5698 AffineMap permutationMap,
5699 std::optional<ArrayRef<bool>> inBounds) {
5700 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5701 auto inBoundsAttr =
5702 (inBounds && !inBounds.value().empty())
5703 ? builder.getBoolArrayAttr(inBounds.value())
5704 : builder.getBoolArrayAttr(SmallVector<bool>(
5705 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5706 build(builder, result, vector, dest, indices, permutationMapAttr,
5707 /*mask=*/Value(), inBoundsAttr);
5708}
5709
5710/// 4. Builder with type inference that sets an empty mask and sets permutation
5711/// map to 'getMinorIdentityMap'.
5712void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5713 Value vector, Value dest, ValueRange indices,
5714 std::optional<ArrayRef<bool>> inBounds) {
5715 auto vectorType = llvm::cast<VectorType>(vector.getType());
5716 AffineMap permutationMap = getTransferMinorIdentityMap(
5717 llvm::cast<ShapedType>(dest.getType()), vectorType);
5718 build(builder, result, vector, dest, indices, permutationMap, inBounds);
5719}
5720
5721ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5722 OperationState &result) {
5723 auto &builder = parser.getBuilder();
5724 SMLoc typesLoc;
5725 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5726 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5727 SmallVector<Type, 2> types;
5728 OpAsmParser::UnresolvedOperand maskInfo;
5729 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5730 parser.parseOperand(sourceInfo) ||
5731 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5732 return failure();
5733 ParseResult hasMask = parser.parseOptionalComma();
5734 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5735 return failure();
5736 if (parser.parseOptionalAttrDict(result.attributes) ||
5737 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5738 return failure();
5739 if (types.size() != 2)
5740 return parser.emitError(typesLoc, "requires two types");
5741 auto indexType = builder.getIndexType();
5742 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5743 if (!vectorType)
5744 return parser.emitError(typesLoc, "requires vector type");
5745 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5746 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5747 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5748 auto permMapAttrName =
5749 TransferWriteOp::getPermutationMapAttrName(result.name);
5750 auto permMapAttr = result.attributes.get(permMapAttrName);
5751 AffineMap permMap;
5752 if (!permMapAttr) {
5753 if (shapedType.getRank() <
5754 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5755 return parser.emitError(typesLoc,
5756 "expected a custom permutation_map when "
5757 "rank(source) != rank(destination)");
5758 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5759 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5760 } else {
5761 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5762 }
5763 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5764 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5765 if (!inBoundsAttr) {
5766 result.addAttribute(inBoundsAttrName,
5767 builder.getBoolArrayAttr(
5768 SmallVector<bool>(permMap.getNumResults(), false)));
5769 }
5770 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5771 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5772 parser.resolveOperands(indexInfo, indexType, result.operands))
5773 return failure();
5774 if (hasMask.succeeded()) {
5775 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5776 return parser.emitError(
5777 maskInfo.location, "does not support masks with vector element type");
5778 if (vectorType.getRank() != permMap.getNumResults()) {
5779 return parser.emitError(typesLoc,
5780 "expected the same rank for the vector and the "
5781 "results of the permutation map");
5782 }
5783 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5784 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5785 return failure();
5786 }
5787 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5788 builder.getDenseI32ArrayAttr(
5789 {1, 1, static_cast<int32_t>(indexInfo.size()),
5790 static_cast<int32_t>(hasMask.succeeded())}));
5791 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5792 parser.addTypeToList(shapedType, result.types));
5793}
5794
5795void TransferWriteOp::print(OpAsmPrinter &p) {
5796 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5797 if (getMask())
5798 p << ", " << getMask();
5799 printTransferAttrs(p, *this);
5800 p << " : " << getVectorType() << ", " << getShapedType();
5801}
5802
5803LogicalResult TransferWriteOp::verify() {
5804 // Consistency of elemental types in shape and vector.
5805 ShapedType shapedType = getShapedType();
5806 VectorType vectorType = getVectorType();
5807 VectorType maskType = getMaskType();
5808 auto permutationMap = getPermutationMap();
5809 VectorType inferredMaskType =
5810 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5811 : VectorType();
5812
5813 if (llvm::size(getIndices()) != shapedType.getRank())
5814 return emitOpError("requires ") << shapedType.getRank() << " indices";
5815
5816 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5817 // as the semantics is unclear. This can be revisited later if necessary.
5818 if (hasBroadcastDim())
5819 return emitOpError("should not have broadcast dimensions");
5820
5821 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5822 shapedType, vectorType, maskType,
5823 inferredMaskType, permutationMap, getInBounds())))
5824 return failure();
5825
5826 return verifyPermutationMap(permutationMap,
5827 [&](Twine t) { return emitOpError(t); });
5828}
5829
5830//===----------------------------------------------------------------------===//
5831// TransferWriteOp: MaskableOpInterface methods.
5832//===----------------------------------------------------------------------===//
5833
5834/// Returns the mask type expected by this operation. Mostly used for
5835/// verification purposes.
5836Type TransferWriteOp::getExpectedMaskType() {
5837 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5838}
5839
5840//===----------------------------------------------------------------------===//
5841// TransferWriteOp: VectorTransferOpInterface methods.
5842//===----------------------------------------------------------------------===//
5843Value TransferWriteOp::getVector() { return getOperand(0); }
5844VectorType TransferWriteOp::getVectorType() {
5845 return cast<VectorType>(getValueToStore().getType());
5846}
5847
5848//===----------------------------------------------------------------------===//
5849// TransferWriteOp: fold methods.
5850//===----------------------------------------------------------------------===//
5851/// Fold:
5852/// ```
5853/// %t1 = ...
5854/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5855/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5856/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5857/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5858/// ```
5859///
5860/// into:
5861///
5862/// ```
5863/// %t0
5864/// ```
5865///
5866/// The producer of t1 may or may not be DCE'd depending on whether it is a
5867/// block argument or has side effects.
5868static LogicalResult foldReadInitWrite(TransferWriteOp write,
5869 ArrayRef<Attribute>,
5870 SmallVectorImpl<OpFoldResult> &results) {
5871 // TODO: support 0-d corner case.
5872 if (write.getTransferRank() == 0)
5873 return failure();
5874 auto rankedTensorType =
5875 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5876 // If not operating on tensors, bail.
5877 if (!rankedTensorType)
5878 return failure();
5879 // If no read, bail.
5880 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5881 if (!read)
5882 return failure();
5883 // TODO: support 0-d corner case.
5884 if (read.getTransferRank() == 0)
5885 return failure();
5886 // For now, only accept minor identity. Future: composition is minor identity.
5887 if (!read.getPermutationMap().isMinorIdentity() ||
5888 !write.getPermutationMap().isMinorIdentity())
5889 return failure();
5890 // Bail on mismatching ranks.
5891 if (read.getTransferRank() != write.getTransferRank())
5892 return failure();
5893 // Bail on potential out-of-bounds accesses.
5894 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5895 return failure();
5896 // Masked transfers have padding/select semantics and are not identity folds.
5897 if (read.getMask() || write.getMask())
5898 return failure();
5899 // Tensor types must be the same.
5900 if (read.getBase().getType() != rankedTensorType)
5901 return failure();
5902 // Vector types must be the same.
5903 if (read.getVectorType() != write.getVectorType())
5904 return failure();
5905 // Vector and Tensor shapes must match.
5906 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5907 return failure();
5908 // If any index is nonzero.
5909 auto isNotConstantZero = [](Value v) {
5910 auto cstOp = getConstantIntValue(v);
5911 return !cstOp.has_value() || cstOp.value() != 0;
5912 };
5913 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5914 llvm::any_of(write.getIndices(), isNotConstantZero))
5915 return failure();
5916 // Success.
5917 results.push_back(read.getBase());
5918 return success();
5919}
5920
5921static bool checkSameValueWAR(vector::TransferReadOp read,
5922 vector::TransferWriteOp write) {
5923 return read.getBase() == write.getBase() &&
5924 read.getIndices() == write.getIndices() &&
5925 read.getPermutationMap() == write.getPermutationMap() &&
5926 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5927 !write.getMask();
5928}
5929/// Fold transfer_write write after read:
5930/// ```
5931/// %t0 = ...
5932/// %v = vector.transfer_read %t0[%c0...] :
5933/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5934/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5935/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5936/// ```
5937///
5938/// into:
5939///
5940/// ```
5941/// %t0
5942/// ```
5943static LogicalResult foldWAR(TransferWriteOp write,
5944 SmallVectorImpl<OpFoldResult> &results) {
5945 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5946 return failure();
5947 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5948 if (!read)
5949 return failure();
5950
5951 if (!checkSameValueWAR(read, write))
5952 return failure();
5953 results.push_back(read.getBase());
5954 return success();
5955}
5956
5957LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5958 SmallVectorImpl<OpFoldResult> &results) {
5959 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5960 return success();
5961 if (succeeded(foldWAR(*this, results)))
5962 return success();
5963 if (succeeded(foldTransferInBoundsAttribute(*this)))
5964 return success();
5965 if (succeeded(foldTransferFullMask(*this)))
5966 return success();
5967 if (succeeded(foldSize1TransferPermutationMap(*this)))
5968 return success();
5969 return memref::foldMemRefCast(*this);
5970}
5971
5972//===----------------------------------------------------------------------===//
5973// TransferWriteOp: other methods.
5974//===----------------------------------------------------------------------===//
5975std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5976 return llvm::to_vector<4>(getVectorType().getShape());
5977}
5978
5979void TransferWriteOp::getEffects(
5980 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5981 &effects) {
5982 if (llvm::isa<MemRefType>(getShapedType()))
5983 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5984 SideEffects::DefaultResource::get());
5985}
5986
/// Returns the speculatability of this transfer_write. The visible code
/// branches on hasPureTensorSemantics().
// NOTE(review): source lines 5989-5990 are absent from this listing, so both
// returned values are not visible here; restore them from the original file.
5987Speculation::Speculatability TransferWriteOp::getSpeculatability() {
5988 if (hasPureTensorSemantics())
5991}
5992
5993namespace {
5994/// Remove dead transfer write from the SSA chain so that it can be eliminated by
5995/// DCE.
5996/// ```
5997/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5998/// : vector<1x4xf32>, tensor<4x4xf32>
5999/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
6000/// : vector<1x4xf32>, tensor<4x4xf32>
6001/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
6002/// : vector<1x4xf32>, tensor<4x4xf32>
6003/// ```
6004///
6005/// into:
6006///
6007/// ```
6008/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
6009/// : vector<1x4xf32>, tensor<4x4xf32>
6010/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
6011/// : vector<1x4xf32>, tensor<4x4xf32>
6012/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
6013/// : vector<1x4xf32>, tensor<4x4xf32>
6014/// ```
6015///
6016/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
6017/// any other uses.
///
/// Only applies to ranked-tensor writes. The pattern walks up the chain of
/// transfer_writes feeding `writeOp`'s base. When it finds an earlier write
/// that checkSameValueWAW reports as overwritten by `writeOp`, it reroutes the
/// write just above it (`writeToModify`) to read that earlier write's base.
6018class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
6019public:
6020 using Base::Base;
6021 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
6022 PatternRewriter &rewriter) const override {
6023 if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
6024 return failure();
// The op whose base operand gets rewired; starts at writeOp and moves up the
// chain as intermediate writes are skipped.
6025 vector::TransferWriteOp writeToModify = writeOp;
6026
6027 auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
6028 while (defWrite) {
6029 if (checkSameValueWAW(writeOp, defWrite)) {
// Bypass defWrite: its written value is fully overwritten by writeOp.
6030 rewriter.modifyOpInPlace(writeToModify, [&]() {
6031 writeToModify.getBaseMutable().assign(defWrite.getBase());
6032 });
6033 return success();
6034 }
// NOTE(review): source line 6035 (the start of the following `if` condition)
// is absent from this listing. The visible arguments are defWrite and writeOp
// as VectorTransferOpInterface, and a true result stops the walk; presumably
// it stops when the two accesses are not provably disjoint. Confirm.
6036 cast<VectorTransferOpInterface>(defWrite.getOperation()),
6037 cast<VectorTransferOpInterface>(writeOp.getOperation())))
6038 break;
6039 // If the previous write op doesn't have any other use we can safely look
6040 // at the previous store to see if it can be removed.
6041 if (!defWrite->hasOneUse())
6042 break;
6043 writeToModify = defWrite;
6044 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
6045 }
6046 return failure();
6047 }
6048};
6049
6050/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
6051/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
6052/// overwritten and inserted into another tensor. After this rewrite, the
6053/// operations bufferize in-place since all of them work on the same slice.
6054///
6055/// For example:
6056/// ```mlir
6057/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
6058/// : vector<8x16xf32>, tensor<8x16xf32>
6059/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
6060/// : tensor<8x16xf32> to tensor<?x?xf32>
6061/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6062/// : tensor<?x?xf32> into tensor<27x37xf32>
6063/// ```
6064/// folds to
6065/// ```mlir
6066/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6067/// : tensor<27x37xf32> to tensor<?x?xf32>
6068/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
6069/// : vector<8x16xf32>, tensor<?x?xf32>
6070/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6071/// : tensor<?x?xf32> into tensor<27x37xf32>
6072/// ```
6073struct SwapExtractSliceOfTransferWrite
6074 : public OpRewritePattern<tensor::InsertSliceOp> {
6075public:
6076 using Base::Base;
6077
6078 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
6079 PatternRewriter &rewriter) const override {
6080 if (!insertOp.hasUnitStride())
6081 return failure();
6082 auto extractOp =
6083 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
6084 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
6085 return failure();
6086 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
6087 if (!transferOp || !transferOp->hasOneUse())
6088 return failure();
6089
6090 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
6091 // rank-reducing.
6092 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
6093 return rewriter.notifyMatchFailure(insertOp,
6094 "use-def chain is rank-reducing");
6095 }
6096
6097 // Fail if tensor::ExtractSliceOp has non-zero offset.
6098 if (!extractOp.hasZeroOffset()) {
6099 return rewriter.notifyMatchFailure(insertOp,
6100 "ExtractSliceOp has non-zero offset");
6101 }
6102
6103 // Fail if tensor::TransferWriteOp has non-zero offset.
6104 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
6105 return getConstantIntValue(value) == static_cast<int64_t>(0);
6106 })) {
6107 return rewriter.notifyMatchFailure(insertOp,
6108 "TranferWriteOp has non-zero offset");
6109 }
6110
6111 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
6112 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
6113 return rewriter.notifyMatchFailure(
6114 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
6115 }
6116
6117 for (auto [insertSize, extractSize] :
6118 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
6119 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
6120 return rewriter.notifyMatchFailure(
6121 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
6122 }
6123 }
6124
6125 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
6126 assert(transferOp.getVectorType().hasStaticShape() &&
6127 "expected vector to have a static shape");
6128 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
6129 SmallVector<int64_t> resultShape = applyPermutationMap(
6130 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
6131 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
6132 return rewriter.notifyMatchFailure(
6133 insertOp, "TransferWriteOp may not write the full tensor.");
6134 }
6135
6136 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
6137 // Set all in_bounds to false and let the folder infer them.
6138 SmallVector<bool> newInBounds(vectorShape.size(), false);
6139 auto newExtractOp = tensor::ExtractSliceOp::create(
6140 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
6141 insertOp.getDest(), insertOp.getMixedOffsets(),
6142 insertOp.getMixedSizes(), insertOp.getMixedStrides());
6143 auto newTransferWriteOp = TransferWriteOp::create(
6144 rewriter, transferOp.getLoc(), transferOp.getVector(),
6145 newExtractOp.getResult(), transferOp.getIndices(),
6146 transferOp.getPermutationMapAttr(),
6147 rewriter.getBoolArrayAttr(newInBounds));
6148 rewriter.modifyOpInPlace(insertOp, [&]() {
6149 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
6150 });
6151 return success();
6152 }
6153};
6154
6155} // namespace
6156
6157void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
6158 MLIRContext *context) {
6159 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
6160}
6161
/// Bubbles memory-space casts down through this op. It fails unless the op
/// has pure buffer semantics.
// NOTE(review): source line 6166 (the head of the return expression) is absent
// from this listing. Only its trailing `ValueRange()` argument is visible;
// presumably it delegates to a shared cast-bubbling helper with no results to
// update. Confirm against the original file.
6162FailureOr<std::optional<SmallVector<Value>>>
6163TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
6164 if (!hasPureBufferSemantics())
6165 return failure();
6167 ValueRange());
6168}
6169
6170//===----------------------------------------------------------------------===//
6171// LoadOp
6172//===----------------------------------------------------------------------===//
6173
6174static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
6175 VectorType vecTy,
6176 MemRefType memRefTy) {
6177 // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
6178 // need any strides limitations.
6179 if (!vecTy.isScalable() &&
6180 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
6181 return success();
6182
6183 if (!memRefTy.isLastDimUnitStride())
6184 return op->emitOpError("most minor memref dim must have unit stride");
6185 return success();
6186}
6187
6188LogicalResult vector::LoadOp::verify() {
6189 VectorType resVecTy = getVectorType();
6190 MemRefType memRefTy = getMemRefType();
6191
6192 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
6193 return failure();
6194
6195 if (memRefTy.getRank() < resVecTy.getRank())
6196 return emitOpError(
6197 "destination memref has lower rank than the result vector");
6198
6199 // Checks for vector memrefs.
6200 Type memElemTy = memRefTy.getElementType();
6201 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6202 if (memVecTy != resVecTy)
6203 return emitOpError("base memref and result vector types should match");
6204 memElemTy = memVecTy.getElementType();
6205 }
6206
6207 if (resVecTy.getElementType() != memElemTy)
6208 return emitOpError("base and result element types should match");
6209 if (llvm::size(getIndices()) != memRefTy.getRank())
6210 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6211 return success();
6212}
6213
6214OpFoldResult LoadOp::fold(FoldAdaptor) {
6215 if (succeeded(memref::foldMemRefCast(*this)))
6216 return getResult();
6217 return OpFoldResult();
6218}
6219
6220std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6221 return llvm::to_vector<4>(getVectorType().getShape());
6222}
6223
/// Bubbles memory-space casts down through this load.
// NOTE(review): source line 6226 (the head of the return expression) is absent
// from this listing. Only its trailing `getResult()` argument is visible;
// presumably a shared helper that also updates the loaded result. Confirm.
6224FailureOr<std::optional<SmallVector<Value>>>
6225LoadOp::bubbleDownCasts(OpBuilder &builder) {
6227 getResult());
6228}
6229
6230//===----------------------------------------------------------------------===//
6231// StoreOp
6232//===----------------------------------------------------------------------===//
6233
6234LogicalResult vector::StoreOp::verify() {
6235 VectorType valueVecTy = getVectorType();
6236 MemRefType memRefTy = getMemRefType();
6237
6238 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
6239 return failure();
6240
6241 if (memRefTy.getRank() < valueVecTy.getRank())
6242 return emitOpError("source memref has lower rank than the vector to store");
6243
6244 // Checks for vector memrefs.
6245 Type memElemTy = memRefTy.getElementType();
6246 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6247 if (memVecTy != valueVecTy)
6248 return emitOpError(
6249 "base memref and valueToStore vector types should match");
6250 memElemTy = memVecTy.getElementType();
6251 }
6252
6253 if (valueVecTy.getElementType() != memElemTy)
6254 return emitOpError("base and valueToStore element type should match");
6255 if (llvm::size(getIndices()) != memRefTy.getRank())
6256 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6257 return success();
6258}
6259
6260LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6261 SmallVectorImpl<OpFoldResult> &results) {
6262 return memref::foldMemRefCast(*this);
6263}
6264
6265std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6266 return llvm::to_vector<4>(getVectorType().getShape());
6267}
6268
/// Bubbles memory-space casts down through this store.
// NOTE(review): source line 6271 (the head of the return expression) is absent
// from this listing. Only its trailing `ValueRange()` argument is visible
// (no results to update). Confirm against the original file.
6269FailureOr<std::optional<SmallVector<Value>>>
6270StoreOp::bubbleDownCasts(OpBuilder &builder) {
6272 ValueRange());
6273}
6274
6275//===----------------------------------------------------------------------===//
6276// MaskedLoadOp
6277//===----------------------------------------------------------------------===//
6278
6279LogicalResult MaskedLoadOp::verify() {
6280 VectorType maskVType = getMaskVectorType();
6281 VectorType passVType = getPassThruVectorType();
6282 VectorType resVType = getVectorType();
6283 MemRefType memType = getMemRefType();
6284
6285 if (failed(
6286 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6287 return failure();
6288 if (llvm::size(getIndices()) != memType.getRank())
6289 return emitOpError("requires ") << memType.getRank() << " indices";
6290 if (resVType.getShape() != maskVType.getShape())
6291 return emitOpError("expected result shape to match mask shape");
6292 if (resVType != passVType)
6293 return emitOpError("expected pass_thru of same type as result type");
6294 return success();
6295}
6296
6297namespace {
/// Canonicalizes vector.maskedload according to getMaskFormat of its mask.
/// Visible outcomes:
///  * replace with an unmasked vector.load of the same base/indices;
///  * replace with the pass_thru value;
///  * fail (no rewrite).
// NOTE(review): the `case` labels (source lines 6304, 6308, 6311) are absent
// from this listing. The bodies are consistent with all-true, all-false and
// unknown mask formats respectively. Confirm against the MaskFormat enum.
6298class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
6299public:
6300 using Base::Base;
6301 LogicalResult matchAndRewrite(MaskedLoadOp load,
6302 PatternRewriter &rewriter) const override {
6303 switch (getMaskFormat(load.getMask())) {
6305 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6306 load, load.getType(), load.getBase(), load.getIndices());
6307 return success();
6309 rewriter.replaceOp(load, load.getPassThru());
6310 return success();
6312 return failure();
6313 }
6314 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
6315 }
6316};
6317} // namespace
6318
6319void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6320 MLIRContext *context) {
6321 results.add<MaskedLoadFolder>(context);
6322}
6323
6324OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6325 if (succeeded(memref::foldMemRefCast(*this)))
6326 return getResult();
6327 return OpFoldResult();
6328}
6329
/// Bubbles memory-space casts down through this masked load.
// NOTE(review): source line 6332 (the head of the return expression) is absent
// from this listing. Only the trailing `getResult()` argument is visible.
// Confirm against the original file.
6330FailureOr<std::optional<SmallVector<Value>>>
6331MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
6333 getResult());
6334}
6335
6336//===----------------------------------------------------------------------===//
6337// MaskedStoreOp
6338//===----------------------------------------------------------------------===//
6339
6340LogicalResult MaskedStoreOp::verify() {
6341 VectorType maskVType = getMaskVectorType();
6342 VectorType valueVType = getVectorType();
6343 MemRefType memType = getMemRefType();
6344
6345 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6346 "valueToStore")))
6347 return failure();
6348 if (llvm::size(getIndices()) != memType.getRank())
6349 return emitOpError("requires ") << memType.getRank() << " indices";
6350 if (valueVType.getShape() != maskVType.getShape())
6351 return emitOpError("expected valueToStore shape to match mask shape");
6352 return success();
6353}
6354
6355namespace {
/// Canonicalizes vector.maskedstore according to getMaskFormat of its mask.
/// Visible outcomes:
///  * replace with an unmasked vector.store;
///  * erase the store;
///  * fail (no rewrite).
// NOTE(review): the `case` labels (source lines 6362, 6366, 6369) are absent
// from this listing. The bodies are consistent with all-true, all-false and
// unknown mask formats respectively. Confirm against the MaskFormat enum.
6356class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
6357public:
6358 using Base::Base;
6359 LogicalResult matchAndRewrite(MaskedStoreOp store,
6360 PatternRewriter &rewriter) const override {
6361 switch (getMaskFormat(store.getMask())) {
6363 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6364 store, store.getValueToStore(), store.getBase(), store.getIndices());
6365 return success();
6367 rewriter.eraseOp(store);
6368 return success();
6370 return failure();
6371 }
6372 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
6373 }
6374};
6375} // namespace
6376
6377void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6378 MLIRContext *context) {
6379 results.add<MaskedStoreFolder>(context);
6380}
6381
6382LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6383 SmallVectorImpl<OpFoldResult> &results) {
6384 return memref::foldMemRefCast(*this);
6385}
6386
/// Bubbles memory-space casts down through this masked store.
// NOTE(review): source line 6389 (the head of the return expression) is absent
// from this listing. Only the trailing `ValueRange()` argument is visible.
// Confirm against the original file.
6387FailureOr<std::optional<SmallVector<Value>>>
6388MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
6390 ValueRange());
6391}
6392
6393//===----------------------------------------------------------------------===//
6394// GatherOp
6395//===----------------------------------------------------------------------===//
6396
6397LogicalResult GatherOp::verify() {
6398 VectorType indVType = getIndexVectorType();
6399 VectorType maskVType = getMaskVectorType();
6400 VectorType resVType = getVectorType();
6401 ShapedType baseType = getBaseType();
6402
6403 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6404 return emitOpError("requires base to be a memref or ranked tensor type");
6405
6406 if (failed(
6407 verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
6408 return failure();
6409 if (llvm::size(getOffsets()) != baseType.getRank())
6410 return emitOpError("requires ") << baseType.getRank() << " indices";
6411 if (resVType.getShape() != indVType.getShape())
6412 return emitOpError("expected result dim to match indices dim");
6413 if (resVType.getShape() != maskVType.getShape())
6414 return emitOpError("expected result dim to match mask dim");
6415 if (resVType != getPassThruVectorType())
6416 return emitOpError("expected pass_thru of same type as result type");
6417 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6418 return emitOpError(
6419 "alignment is only supported for memref bases, not tensor bases");
6420 }
6421 return success();
6422}
6423
6424// MaskableOpInterface methods.
6425
6426/// Returns the mask type expected by this operation. Mostly used for
6427/// verification purposes. It requires the operation to be vectorized."
6428Type GatherOp::getExpectedMaskType() {
6429 auto vecType = this->getIndexVectorType();
6430 return VectorType::get(vecType.getShape(),
6431 IntegerType::get(vecType.getContext(), /*width=*/1),
6432 vecType.getScalableDims());
6433}
6434
6435std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6436 return llvm::to_vector<4>(getVectorType().getShape());
6437}
6438
6439/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
6440static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6441 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6442 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6443 return failure();
6444
6445 if (indexVec.getDefiningOp<StepOp>())
6446 return success();
6447
6448 DenseIntElementsAttr elements;
6449 if (!matchPattern(indexVec, m_Constant(&elements)))
6450 return failure();
6451
6452 return success(
6453 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6454}
6455
6456namespace {
/// Canonicalizes vector.gather according to getMaskFormat of its mask.
/// Visible outcomes:
///  * fail, because a gather has no unmasked equivalent;
///  * replace with the pass_thru value;
///  * fail (no rewrite).
// NOTE(review): the `case` labels (source lines 6463, 6465, 6468) are absent
// from this listing. The bodies are consistent with all-true, all-false and
// unknown mask formats respectively. Confirm against the MaskFormat enum.
6457class GatherFolder final : public OpRewritePattern<GatherOp> {
6458public:
6459 using Base::Base;
6460 LogicalResult matchAndRewrite(GatherOp gather,
6461 PatternRewriter &rewriter) const override {
6462 switch (getMaskFormat(gather.getMask())) {
6464 return failure(); // no unmasked equivalent
6466 rewriter.replaceOp(gather, gather.getPassThru());
6467 return success();
6469 return failure();
6470 }
6471 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
6472 }
6473};
6474
6475/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6476/// maskedload. Only 1D fixed vectors are supported for now.
6477class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6478public:
6479 using Base::Base;
6480 LogicalResult matchAndRewrite(GatherOp op,
6481 PatternRewriter &rewriter) const override {
6482 if (!isa<MemRefType>(op.getBase().getType()))
6483 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6484
6485 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6486 return failure();
6487
6488 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6489 op.getOffsets(), op.getMask(),
6490 op.getPassThru());
6491 return success();
6492 }
6493};
6494} // namespace
6495
6496void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6497 MLIRContext *context) {
6498 results.add<GatherFolder, FoldContiguousGather>(context);
6499}
6500
/// Bubbles memory-space casts down through this gather.
// NOTE(review): source line 6503 (the head of the return expression) is absent
// from this listing. Only the trailing `getResult()` argument is visible.
// Confirm against the original file.
6501FailureOr<std::optional<SmallVector<Value>>>
6502GatherOp::bubbleDownCasts(OpBuilder &builder) {
6504 getResult());
6505}
6506
6507//===----------------------------------------------------------------------===//
6508// ScatterOp
6509//===----------------------------------------------------------------------===//
6510
6511LogicalResult ScatterOp::verify() {
6512 VectorType indVType = getIndexVectorType();
6513 VectorType maskVType = getMaskVectorType();
6514 VectorType valueVType = getVectorType();
6515 ShapedType baseType = getBaseType();
6516
6517 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6518 return emitOpError("requires base to be a memref or ranked tensor type");
6519
6520 if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
6521 "valueToStore")))
6522 return failure();
6523 if (llvm::size(getOffsets()) != baseType.getRank())
6524 return emitOpError("requires ") << baseType.getRank() << " indices";
6525 if (valueVType.getShape() != indVType.getShape())
6526 return emitOpError("expected valueToStore dim to match indices dim");
6527 if (valueVType.getShape() != maskVType.getShape())
6528 return emitOpError("expected valueToStore dim to match mask dim");
6529 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6530 return emitOpError(
6531 "alignment is only supported for memref bases, not tensor bases");
6532 }
6533 return success();
6534}
6535namespace {
/// Canonicalizes vector.scatter according to getMaskFormat of its mask.
/// It only handles memref or ranked-tensor bases. Visible outcomes:
///  * fail, because a scatter has no unmasked equivalent;
///  * erase the op (memref base) or replace it with its base tensor;
///  * fail (no rewrite).
// NOTE(review): the `case` labels (source lines 6550, 6552, 6558) are absent
// from this listing. The bodies are consistent with all-true, all-false and
// unknown mask formats respectively. Confirm against the MaskFormat enum.
6536class ScatterFolder final : public OpRewritePattern<ScatterOp> {
6537public:
6538 using Base::Base;
6539 LogicalResult matchAndRewrite(ScatterOp scatter,
6540 PatternRewriter &rewriter) const override {
6541 ShapedType baseType = scatter.getBaseType();
6542 bool isMemRef = isa<MemRefType>(baseType);
6543 if (!isMemRef && !isa<RankedTensorType>(baseType))
6544 return failure();
6545
6546 // Memrefs have no result, so an all-false mask can simply erase the op.
6547 // Tensors carry the updated value, so we must replace uses with the
6548 // original base tensor instead of erasing.
6549 switch (getMaskFormat(scatter.getMask())) {
6551 return failure(); // no unmasked equivalent
6553 if (isMemRef)
6554 rewriter.eraseOp(scatter);
6555 else
6556 rewriter.replaceOp(scatter, scatter.getBase());
6557 return success();
6559 return failure();
6560 }
6561 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
6562 }
6563};
6564
6565/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6566/// maskedstore. Only 1D fixed vectors are supported for now.
6567class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6568public:
6569 using Base::Base;
6570 LogicalResult matchAndRewrite(ScatterOp op,
6571 PatternRewriter &rewriter) const override {
6572 // Fold only for memrefs: the replacement uses maskedstore, which does not
6573 // support tensor bases. Tensor cases intentionally bail out.
6574 if (!isa<MemRefType>(op.getBase().getType()))
6575 return failure();
6576
6577 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6578 return failure();
6579
6580 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6581 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6582 return success();
6583 }
6584};
6585} // namespace
6586
6587void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6588 MLIRContext *context) {
6589 results.add<ScatterFolder, FoldContiguousScatter>(context);
6590}
6591
/// Bubbles memory-space casts down through this scatter.
// NOTE(review): source line 6594 (the head of the return expression) is absent
// from this listing. Only the trailing `ValueRange()` argument is visible.
// Confirm against the original file.
6592FailureOr<std::optional<SmallVector<Value>>>
6593ScatterOp::bubbleDownCasts(OpBuilder &builder) {
6595 ValueRange());
6596}
6597
6598//===----------------------------------------------------------------------===//
6599// ExpandLoadOp
6600//===----------------------------------------------------------------------===//
6601
6602LogicalResult ExpandLoadOp::verify() {
6603 VectorType maskVType = getMaskVectorType();
6604 VectorType passVType = getPassThruVectorType();
6605 VectorType resVType = getVectorType();
6606 MemRefType memType = getMemRefType();
6607
6608 if (failed(
6609 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6610 return failure();
6611 if (llvm::size(getIndices()) != memType.getRank())
6612 return emitOpError("requires ") << memType.getRank() << " indices";
6613 if (resVType.getShape() != maskVType.getShape())
6614 return emitOpError("expected result shape to match mask shape");
6615 if (resVType != passVType)
6616 return emitOpError("expected pass_thru of same type as result type");
6617 return success();
6618}
6619
6620namespace {
/// Canonicalizes `vector.expandload` according to the classification of its
/// mask (`getMaskFormat`). It is rewritten to an unmasked `vector.load` (the
/// all-true case) or replaced by its pass-through value (the all-false case).
/// Otherwise the pattern fails to match.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
  }
};
6640} // namespace
6641
6642void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6643 MLIRContext *context) {
6644 results.add<ExpandLoadFolder>(context);
6645}
6646
/// Implements the `bubbleDownCasts` interface hook for `vector.expandload`.
/// The op's result (`getResult()`) is passed as the trailing argument.
FailureOr<std::optional<SmallVector<Value>>>
ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
                                                            getResult());
}
6652
6653//===----------------------------------------------------------------------===//
6654// CompressStoreOp
6655//===----------------------------------------------------------------------===//
6656
6657LogicalResult CompressStoreOp::verify() {
6658 VectorType maskVType = getMaskVectorType();
6659 VectorType valueVType = getVectorType();
6660 MemRefType memType = getMemRefType();
6661
6662 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6663 "valueToStore")))
6664 return failure();
6665 if (llvm::size(getIndices()) != memType.getRank())
6666 return emitOpError("requires ") << memType.getRank() << " indices";
6667 if (valueVType.getShape() != maskVType.getShape())
6668 return emitOpError("expected valueToStore shape to match mask shape");
6669 return success();
6670}
6671
6672namespace {
/// Canonicalizes `vector.compressstore` according to the classification of
/// its mask (`getMaskFormat`). It is rewritten to an unmasked `vector.store`
/// (the all-true case) or erased outright (the all-false case, where nothing
/// is written). Otherwise the pattern fails to match.
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
      rewriter.eraseOp(compress);
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
  }
};
6693} // namespace
6694
6695void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6696 MLIRContext *context) {
6697 results.add<CompressStoreFolder>(context);
6698}
6699
/// Implements the `bubbleDownCasts` interface hook for `vector.compressstore`.
/// An empty `ValueRange` is passed as the trailing argument, consistent with
/// the op writing to memory rather than producing a value.
FailureOr<std::optional<SmallVector<Value>>>
CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
                                                            ValueRange());
}
6705
6706//===----------------------------------------------------------------------===//
6707// ShapeCastOp
6708//===----------------------------------------------------------------------===//
6709
6710void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
6711 SetIntRangeFn setResultRanges) {
6712 setResultRanges(getResult(), argRanges.front());
6713}
6714
6715std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6716 return llvm::to_vector<4>(getResultVectorType().getShape());
6717}
6718
6719LogicalResult ShapeCastOp::verify() {
6720
6721 VectorType sourceType = getSourceVectorType();
6722 VectorType resultType = getResultVectorType();
6723
6724 // Check that element type is preserved
6725 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
6726 "result")))
6727 return failure();
6728
6729 // Check that number of elements is preserved
6730 int64_t sourceNElms = sourceType.getNumElements();
6731 int64_t resultNElms = resultType.getNumElements();
6732 if (sourceNElms != resultNElms) {
6733 return emitOpError() << "has different number of elements at source ("
6734 << sourceNElms << ") and result (" << resultNElms
6735 << ")";
6736 }
6737
6738 // Check that (non-)scalability is preserved
6739 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6740 int64_t resultNScalableDims = resultType.getNumScalableDims();
6741 if (sourceNScalableDims != resultNScalableDims)
6742 return emitOpError() << "has different number of scalable dims at source ("
6743 << sourceNScalableDims << ") and result ("
6744 << resultNScalableDims << ")";
6745
6746 return success();
6747}
6748
6749/// Return true if `transpose` does not permute a pair of non-unit dims.
6750/// By `order preserving` we mean that the flattened versions of the input and
6751/// output vectors are (numerically) identical. In other words `transpose` is
6752/// effectively a shape cast.
6753static bool isOrderPreserving(TransposeOp transpose) {
6754 ArrayRef<int64_t> permutation = transpose.getPermutation();
6755 VectorType sourceType = transpose.getSourceVectorType();
6756 ArrayRef<int64_t> inShape = sourceType.getShape();
6757 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6758 auto isNonScalableUnitDim = [&](int64_t dim) {
6759 return inShape[dim] == 1 && !inDimIsScalable[dim];
6760 };
6761 int64_t current = 0;
6762 for (auto p : permutation) {
6763 if (!isNonScalableUnitDim(p)) {
6764 if (p < current) {
6765 return false;
6766 }
6767 current = p;
6768 }
6769 }
6770 return true;
6771}
6772
6773OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
6774
6775 VectorType resultType = getType();
6776
6777 // No-op shape cast.
6778 if (getSource().getType() == resultType)
6779 return getSource();
6780
6781 // shape_cast(shape_cast(x)) -> shape_cast(x)
6782 if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
6783 setOperand(precedingShapeCast.getSource());
6784 return getResult();
6785 }
6786
6787 // shape_cast(transpose(x)) -> shape_cast(x)
6788 if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
6789 if (isOrderPreserving(transpose)) {
6790 setOperand(transpose.getVector());
6791 return getResult();
6792 }
6793 return {};
6794 }
6795
6796 // Y = shape_cast(broadcast(X))
6797 // -> X, if X and Y have same type
6798 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
6799 if (bcastOp.getSourceType() == resultType)
6800 return bcastOp.getSource();
6801 }
6802
6803 // shape_cast(constant) -> constant
6804 if (auto denseAttr =
6805 dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
6806 return denseAttr.reshape(getType());
6807
6808 // shape_cast(poison) -> poison
6809 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
6810 return ub::PoisonAttr::get(getContext());
6811
6812 return {};
6813}
6814
6815namespace {
6816
6817/// Helper function that computes a new vector type based on the input vector
6818/// type by removing the trailing one dims:
6819///
6820/// vector<4x1x1xi1> --> vector<4x1xi1>
6821///
6822static VectorType trimTrailingOneDims(VectorType oldType) {
6823 ArrayRef<int64_t> oldShape = oldType.getShape();
6824 ArrayRef<int64_t> newShape = oldShape;
6825
6826 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6827 ArrayRef<bool> newScalableDims = oldScalableDims;
6828
6829 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6830 newShape = newShape.drop_back(1);
6831 newScalableDims = newScalableDims.drop_back(1);
6832 }
6833
6834 // Make sure we have at least 1 dimension.
6835 // TODO: Add support for 0-D vectors.
6836 if (newShape.empty()) {
6837 newShape = oldShape.take_back();
6838 newScalableDims = oldScalableDims.take_back();
6839 }
6840
6841 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6842}
6843
6844/// Folds qualifying shape_cast(create_mask) into a new create_mask
6845///
6846/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6847/// dimension. If the input vector comes from `vector.create_mask` for which
6848/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6849/// to fold shape_cast into create_mask.
6850///
6851/// BEFORE:
6852/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6853/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6854/// AFTER:
6855/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6856class ShapeCastCreateMaskFolderTrailingOneDim final
6857 : public OpRewritePattern<ShapeCastOp> {
6858public:
6859 using Base::Base;
6860
6861 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6862 PatternRewriter &rewriter) const override {
6863 Value shapeOpSrc = shapeOp->getOperand(0);
6864 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6865 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6866 if (!createMaskOp && !constantMaskOp)
6867 return failure();
6868
6869 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6870 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6871
6872 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6873 if (newVecType != shapeOpResTy)
6874 return failure();
6875
6876 auto numDimsToDrop =
6877 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6878
6879 // No unit dims to drop
6880 if (!numDimsToDrop)
6881 return failure();
6882
6883 if (createMaskOp) {
6884 auto maskOperands = createMaskOp.getOperands();
6885 auto numMaskOperands = maskOperands.size();
6886
6887 // Check every mask dim size to see whether it can be dropped
6888 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6889 --i) {
6890 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6891 if (!constant || (constant.value() != 1))
6892 return failure();
6893 }
6894 SmallVector<Value> newMaskOperands =
6895 maskOperands.drop_back(numDimsToDrop);
6896
6897 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6898 newMaskOperands);
6899 return success();
6900 }
6901
6902 if (constantMaskOp) {
6903 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6904 auto numMaskOperands = maskDimSizes.size();
6905
6906 // Check every mask dim size to see whether it can be dropped
6907 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6908 --i) {
6909 if (maskDimSizes[i] != 1)
6910 return failure();
6911 }
6912
6913 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6914 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6915 newMaskOperands);
6916 return success();
6917 }
6918
6919 return failure();
6920 }
6921};
6922
6923// vector.broadcast has two distinct semantic modes: duplication across leading
6924// dimensions, and stretching across inner dimensions. This helper returns the
6925// product of the inner-dimension stretching factors.
6926int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6927 ArrayRef<int64_t> dstShape) {
6928 int stretchingFactor = 1;
6929 int numLeadingDims = dstShape.size() - srcShape.size();
6930 for (int i = 0, e = srcShape.size(); i < e; i++) {
6931 int64_t dstDim = dstShape[numLeadingDims + i];
6932 if (srcShape[i] == 1 && dstDim != 1) {
6933 stretchingFactor *= dstDim;
6934 }
6935 }
6936 return stretchingFactor;
6937}
6938
6939/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
6940class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6941public:
6942 using Base::Base;
6943
6944 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6945 PatternRewriter &rewriter) const override {
6946 auto broadcastOp =
6947 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6948 if (!broadcastOp)
6949 return failure();
6950
6951 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6952 bool srcIsScalar = !srcVectorType;
6953
6954 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6955 // Example
6956 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6957 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6958 // to
6959 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6960 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6961 ArrayRef<int64_t> dstShape = dstVectorType.getShape();
6962 ArrayRef<int64_t> srcShape =
6963 srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
6964 ArrayRef<int64_t> broadcastShape =
6965 broadcastOp.getResultVectorType().getShape();
6966
6967 if (!srcIsScalar) {
6968 if (isBroadcastableTo(srcVectorType, dstVectorType) !=
6969 BroadcastableToResult::Success) {
6970 return failure();
6971 }
6972 // Avoid folding if this would result in switching between the two
6973 // distinct semantic modes of vector.broadcast (duplication vs
6974 // stretching). See https://github.com/llvm/llvm-project/issues/190614.
6975 // This is detected by a change in the stretching factor. However if the
6976 // source has a single element, there is no ambiguity.
6977 if (srcVectorType.getNumElements() != 1) {
6978 if (getBroadcastStretchingFactor(srcShape, dstShape) !=
6979 getBroadcastStretchingFactor(srcShape, broadcastShape)) {
6980 return failure();
6981 }
6982 }
6983 }
6984
6985 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
6986 broadcastOp.getSource());
6987 return success();
6988 }
6989};
6990
6991/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X)
6992///
6993/// BEFORE:
6994/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
6995/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
6996/// AFTER:
6997/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
6998///
6999/// Note: this transformation is implemented as an OpRewritePattern, not as a
7000/// fold, because we have to create new op FromElementsOp with updated result
7001/// type. This cannot be done with a fold, because fold cannot create new ops
7002/// and the existing FromElementsOp result type differs from the ShapeCastOp
7003/// result type. Mutating the FromElementsOp (not root op) would violate the
7004/// fold contract and break other users.
7005class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
7006public:
7007 using Base::Base;
7008
7009 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
7010 PatternRewriter &rewriter) const override {
7011 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
7012 if (!fromElements)
7013 return failure();
7014
7015 rewriter.replaceOpWithNewOp<FromElementsOp>(
7016 shapeCastOp, shapeCastOp.getResultVectorType(),
7017 fromElements.getElements());
7018 return success();
7019 }
7020};
7021
7022} // namespace
7023
7024void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
7025 MLIRContext *context) {
7026 results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
7027 FoldShapeCastOfFromElements>(context);
7028}
7029
7030//===----------------------------------------------------------------------===//
7031// VectorBitCastOp
7032//===----------------------------------------------------------------------===//
7033
7034LogicalResult BitCastOp::verify() {
7035 auto sourceVectorType = getSourceVectorType();
7036 auto resultVectorType = getResultVectorType();
7037
7038 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
7039 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
7040 return emitOpError("dimension size mismatch at: ") << i;
7041 }
7042
7043 DataLayout dataLayout = DataLayout::closest(*this);
7044 auto sourceElementBits =
7045 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
7046 auto resultElementBits =
7047 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
7048
7049 if (sourceVectorType.getRank() == 0) {
7050 if (sourceElementBits != resultElementBits)
7051 return emitOpError("source/result bitwidth of the 0-D vector element "
7052 "types must be equal");
7053 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
7054 resultElementBits * resultVectorType.getShape().back()) {
7055 return emitOpError(
7056 "source/result bitwidth of the minor 1-D vectors must be equal");
7057 }
7058
7059 return success();
7060}
7061
7062OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
7063 // Nop cast.
7064 if (getSource().getType() == getResult().getType())
7065 return getSource();
7066
7067 // Canceling bitcasts.
7068 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
7069 if (getResult().getType() == otherOp.getSource().getType())
7070 return otherOp.getSource();
7071
7072 setOperand(otherOp.getSource());
7073 return getResult();
7074 }
7075
7076 Attribute sourceConstant = adaptor.getSource();
7077 if (!sourceConstant)
7078 return {};
7079
7080 Type srcElemType = getSourceVectorType().getElementType();
7081 Type dstElemType = getResultVectorType().getElementType();
7082
7083 if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
7084 if (floatPack.isSplat()) {
7085 auto splat = floatPack.getSplatValue<FloatAttr>();
7086
7087 // Casting fp16 into fp32.
7088 if (srcElemType.isF16() && dstElemType.isF32()) {
7089 uint32_t bits = static_cast<uint32_t>(
7090 splat.getValue().bitcastToAPInt().getZExtValue());
7091 // Duplicate the 16-bit pattern.
7092 bits = (bits << 16) | (bits & 0xffff);
7093 APInt intBits(32, bits);
7094 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
7095 return DenseElementsAttr::get(getResultVectorType(), floatBits);
7096 }
7097 }
7098 }
7099
7100 if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
7101 if (intPack.isSplat()) {
7102 auto splat = intPack.getSplatValue<IntegerAttr>();
7103
7104 if (llvm::isa<IntegerType>(dstElemType) && srcElemType.isIntOrFloat()) {
7105 uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
7106 uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
7107
7108 // Casting to a larger integer bit width.
7109 if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
7110 APInt intBits = splat.getValue().zext(dstBitWidth);
7111
7112 // Duplicate the lower width element.
7113 for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
7114 intBits = (intBits << srcBitWidth) | intBits;
7115 return DenseElementsAttr::get(getResultVectorType(), intBits);
7116 }
7117 }
7118 }
7119 }
7120
7121 return {};
7122}
7123
7124std::optional<SmallVector<int64_t, 4>> BitCastOp::getShapeForUnroll() {
7125 return llvm::to_vector<4>(getResultVectorType().getShape());
7126}
7127
7128//===----------------------------------------------------------------------===//
7129// TypeCastOp
7130//===----------------------------------------------------------------------===//
7131
7132static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
7133 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
7134 SmallVector<int64_t, 8> res(memRefType.getShape());
7135 if (vectorType)
7136 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
7137 return res;
7138}
7139
/// Build the canonical memRefType with a single vector.
/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
///
/// The result type is a 0-D memref (empty shape) whose element is a vector
/// built from `extractShape(memRefType)`. That is the memref shape
/// concatenated with the shape of a vector element type, if any. The result
/// keeps the source memory space and uses the default (null) layout.
void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                       Value source) {
  result.addOperands(source);
  MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
  VectorType vectorType =
      VectorType::get(extractShape(memRefType),
  result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                  memRefType.getMemorySpace()));
}
7152
/// Verifies `vector.type_cast`:
///  * the operand (after strided-layout canonicalization) and the result both
///    have identity layouts;
///  * both live in the same memory space;
///  * the underlying scalar element types match;
///  * the concatenated shapes (memref shape plus any vector element shape,
///    see `extractShape`) are equal.
LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  // The double getElementTypeOrSelf peels memref -> (vector ->) scalar so that
  // only the innermost scalar type is compared.
  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
    return emitOpError(
        "expects result and operand with same underlying scalar type: ")
           << resultType;
  // The cast only regroups dims between the memref and its vector element, so
  // the flattened dim lists must agree.
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
        "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
7176
7177//===----------------------------------------------------------------------===//
7178// TransposeOp
7179//===----------------------------------------------------------------------===//
7180
7181void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
7182 Value vector, ArrayRef<int64_t> permutation) {
7183 VectorType vt = llvm::cast<VectorType>(vector.getType());
7184 SmallVector<int64_t, 4> transposedShape(vt.getRank());
7185 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
7186 for (unsigned i = 0; i < permutation.size(); ++i) {
7187 transposedShape[i] = vt.getShape()[permutation[i]];
7188 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
7189 }
7190
7191 result.addOperands(vector);
7192 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
7193 transposedScalableDims));
7194 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
7195 builder.getDenseI64ArrayAttr(permutation));
7196}
7197
7198OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
7199 // Eliminate splat constant transpose ops.
7200 if (auto splat =
7201 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
7202 return splat.reshape(getResultVectorType());
7203
7204 // Eliminate poison transpose ops.
7205 if (matchPattern(adaptor.getVector(), ub::m_Poison()))
7206 return ub::PoisonAttr::get(getContext());
7207
7208 // Eliminate identity transposes, and more generally any transposes that
7209 // preserves the shape without permuting elements.
7210 //
7211 // Examples of what to fold:
7212 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
7213 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
7214 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
7215 //
7216 // Example of what NOT to fold:
7217 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
7218 //
7219 if (getSourceVectorType() == getResultVectorType() &&
7220 isOrderPreserving(*this))
7221 return getVector();
7222
7223 return {};
7224}
7225
7226LogicalResult vector::TransposeOp::verify() {
7227 VectorType vectorType = getSourceVectorType();
7228 VectorType resultType = getResultVectorType();
7229 int64_t rank = resultType.getRank();
7230 if (vectorType.getRank() != rank)
7231 return emitOpError("vector result rank mismatch: ") << rank;
7232 // Verify transposition array.
7233 ArrayRef<int64_t> perm = getPermutation();
7234 int64_t size = perm.size();
7235 if (rank != size)
7236 return emitOpError("transposition length mismatch: ") << size;
7237 SmallVector<bool, 8> seen(rank, false);
7238 for (const auto &ta : llvm::enumerate(perm)) {
7239 if (ta.value() < 0 || ta.value() >= rank)
7240 return emitOpError("transposition index out of range: ") << ta.value();
7241 if (seen[ta.value()])
7242 return emitOpError("duplicate position index: ") << ta.value();
7243 seen[ta.value()] = true;
7244 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7245 return emitOpError("dimension size mismatch at: ") << ta.value();
7246 }
7247 return success();
7248}
7249
7250std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7251 return llvm::to_vector<4>(getResultVectorType().getShape());
7252}
7253
7254void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
7255 SetIntRangeFn setResultRanges) {
7256 setResultRanges(getResult(), argRanges.front());
7257}
7258
7259namespace {
7260
7261// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
7262class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
7263public:
7264 using Base::Base;
7265
7266 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7267 PatternRewriter &rewriter) const override {
7268 // Composes two permutations: result[i] = permutation1[permutation2[i]].
7269 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7270 ArrayRef<int64_t> permutation2) {
7271 SmallVector<int64_t, 4> result;
7272 for (auto index : permutation2)
7273 result.push_back(permutation1[index]);
7274 return result;
7275 };
7276
7277 // Return if the input of 'transposeOp' is not defined by another transpose.
7278 vector::TransposeOp parentTransposeOp =
7279 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7280 if (!parentTransposeOp)
7281 return failure();
7282
7283 SmallVector<int64_t, 4> permutation = composePermutations(
7284 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7285 // Replace 'transposeOp' with a new transpose operation.
7286 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
7287 transposeOp, transposeOp.getResult().getType(),
7288 parentTransposeOp.getVector(), permutation);
7289 return success();
7290 }
7291};
7292
7293/// Replace transpose(splat-like(v)) with broadcast(v)
7294class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
7295public:
7296 using Base::Base;
7297
7298 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7299 PatternRewriter &rewriter) const override {
7300 Value splat = getScalarSplatSource(transposeOp.getVector());
7301 if (!splat)
7302 return failure();
7303
7304 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
7305 transposeOp, transposeOp.getResultVectorType(), splat);
7306 return success();
7307 }
7308};
7309
7310/// Folds transpose(create_mask) into a new transposed create_mask.
7311class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
7312public:
7313 using Base::Base;
7314
7315 LogicalResult matchAndRewrite(TransposeOp transpOp,
7316 PatternRewriter &rewriter) const override {
7317 Value transposeSrc = transpOp.getVector();
7318 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
7319 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
7320 if (!createMaskOp && !constantMaskOp)
7321 return failure();
7322
7323 // Get the transpose permutation and apply it to the vector.create_mask or
7324 // vector.constant_mask operands.
7325 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7326
7327 if (createMaskOp) {
7328 auto maskOperands = createMaskOp.getOperands();
7329 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7330 applyPermutationToVector(newOperands, permutation);
7331
7332 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
7333 transpOp, transpOp.getResultVectorType(), newOperands);
7334 return success();
7335 }
7336
7337 // ConstantMaskOp case.
7338 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7339 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
7340
7341 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
7342 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7343 return success();
7344 }
7345};
7346
7347/// Folds transpose(shape_cast) into a new shape_cast.
7348class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
7349public:
7350 using Base::Base;
7351
7352 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7353 PatternRewriter &rewriter) const override {
7354 auto shapeCastOp =
7355 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7356 if (!shapeCastOp)
7357 return failure();
7358 if (!isOrderPreserving(transposeOp))
7359 return failure();
7360
7361 VectorType resultType = transposeOp.getType();
7362
7363 // We don't need to check isValidShapeCast at this point, because it is
7364 // guaranteed that merging the transpose into the the shape_cast is a valid
7365 // shape_cast, because the transpose just inserts/removes ones.
7366
7367 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
7368 shapeCastOp.getSource());
7369 return success();
7370 }
7371};
7372
7373/// Folds transpose(from_elements(...)) into a new from_elements with permuted
7374/// operands matching the transposed shape.
7375///
7376/// Example:
7377///
7378/// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
7379/// vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to
7380/// vector<3x2xi32>
7381///
7382/// becomes ->
7383///
7384/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
7385/// vector<3x2xi32>
7386///
7387class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
7388public:
7389 using Base::Base;
7390 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7391 PatternRewriter &rewriter) const override {
7392 auto fromElementsOp =
7393 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7394 if (!fromElementsOp)
7395 return failure();
7396
7397 VectorType srcTy = fromElementsOp.getDest().getType();
7398 VectorType dstTy = transposeOp.getType();
7399
7400 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7401 int64_t rank = srcTy.getRank();
7402
7403 // Build inverse permutation to map destination indices back to source.
7404 SmallVector<int64_t> inversePerm(rank, 0);
7405 for (int64_t i = 0; i < rank; ++i)
7406 inversePerm[permutation[i]] = i;
7407
7408 ArrayRef<int64_t> srcShape = srcTy.getShape();
7409 ArrayRef<int64_t> dstShape = dstTy.getShape();
7410 SmallVector<int64_t> srcIdx(rank, 0);
7411 SmallVector<int64_t> dstIdx(rank, 0);
7412 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
7413 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
7414
7415 auto elementsOld = fromElementsOp.getElements();
7416 SmallVector<Value> elementsNew;
7417 int64_t dstNumElements = dstTy.getNumElements();
7418 elementsNew.reserve(dstNumElements);
7419
7420 // For each element in destination row-major order, pick the corresponding
7421 // source element.
7422 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7423 // Pick the destination element index.
7424 dstIdx = delinearize(linearIdx, dstStrides);
7425 // Map the destination element index to the source element index.
7426 for (int64_t j = 0; j < rank; ++j)
7427 srcIdx[j] = dstIdx[inversePerm[j]];
7428 // Linearize the source element index.
7429 int64_t srcLin = linearize(srcIdx, srcStrides);
7430 // Add the source element to the new elements.
7431 elementsNew.push_back(elementsOld[srcLin]);
7432 }
7433
7434 rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
7435 elementsNew);
7436 return success();
7437 }
7438};
7439
7440/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
7441/// 'order preserving', where 'order preserving' means the flattened
7442/// inputs and outputs of the transpose have identical (numerical) values.
7443///
7444/// Example:
7445/// ```
7446/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
7447/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
7448/// to vector<8x1xi32>
7449/// ```
7450/// can be rewritten as the equivalent
7451/// ```
7452/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
7453/// ```
7454/// The algorithm works by partitioning dimensions into groups that can be
7455/// locally permuted while preserving order, and checks that the transpose
7456/// only permutes within these groups.
7457///
7458/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
7459/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
7460/// broadcasting from 1x1x4x1x1x7.
7461/// ^^^ ^ ^^^ ^
7462/// groups: 0 1 2 3
7463/// Order preserving permutations for this example are ones that only permute
7464/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
7465class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
7466public:
7467 using Base::Base;
7468 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7469 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7470
7471 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7472 PatternRewriter &rewriter) const override {
7473
7474 vector::BroadcastOp broadcast =
7475 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
7476 if (!broadcast) {
7477 return rewriter.notifyMatchFailure(transpose,
7478 "not preceded by a broadcast");
7479 }
7480
7481 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
7482 VectorType outputType = transpose.getResultVectorType();
7483
7484 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
7485 bool inputIsScalar = !inputType;
7486 if (inputIsScalar) {
7487 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7488 broadcast.getSource());
7489 return success();
7490 }
7491
7492 ArrayRef<int64_t> permutation = transpose.getPermutation();
7493 ArrayRef<int64_t> inputShape = inputType.getShape();
7494 int64_t inputRank = inputType.getRank();
7495 int64_t outputRank = transpose.getType().getRank();
7496 int64_t deltaRank = outputRank - inputRank;
7497
7498 int low = 0;
7499 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7500 bool notOne = inputShape[inputIndex] != 1;
7501 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7502 bool groupEndFound = notOne || prevNotOne;
7503 if (groupEndFound) {
7504 int high = inputIndex + deltaRank;
7505 // Return failure if not all permutation destinations for indices in
7506 // [low, high) are in [low, high), i.e. the permutation is not local to
7507 // the group.
7508 for (int i = low; i < high; ++i) {
7509 if (permutation[i] < low || permutation[i] >= high) {
7510 return rewriter.notifyMatchFailure(
7511 transpose, "permutation not local to group");
7512 }
7513 }
7514 low = high;
7515 }
7516 }
7517
7518 // We don't need to check the final group [low, outputRank) because if it is
7519 // not locally bound, there must be a preceding group that already failed
7520 // the check (impossible to have just 1 non-locally bound group).
7521
7522 // The preceding logic also ensures that at this point, the output of the
7523 // transpose is definitely broadcastable from the input shape, assert so:
7524 assert(vector::isBroadcastableTo(inputType, outputType) ==
7525 vector::BroadcastableToResult::Success &&
7526 "not broadcastable directly to transpose output");
7527
7528 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7529 broadcast.getSource());
7530
7531 return success();
7532 }
7533};
7534
7535} // namespace
7536
7537void vector::TransposeOp::getCanonicalizationPatterns(
7538 RewritePatternSet &results, MLIRContext *context) {
7539 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7540 FoldTransposeSplat, FoldTransposeFromElements,
7541 FoldTransposeBroadcast>(context);
7542}
7543
7544//===----------------------------------------------------------------------===//
7545// ConstantMaskOp
7546//===----------------------------------------------------------------------===//
7547
7548void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7549 VectorType type, ConstantMaskKind kind) {
7550 assert(kind == ConstantMaskKind::AllTrue ||
7551 kind == ConstantMaskKind::AllFalse);
7552 build(builder, result, type,
7553 kind == ConstantMaskKind::AllTrue
7554 ? type.getShape()
7555 : SmallVector<int64_t>(type.getRank(), 0));
7556}
7557
7558LogicalResult ConstantMaskOp::verify() {
7559 auto resultType = llvm::cast<VectorType>(getResult().getType());
7560 // Check the corner case of 0-D vectors first.
7561 if (resultType.getRank() == 0) {
7562 if (getMaskDimSizes().size() != 1)
7563 return emitError("array attr must have length 1 for 0-D vectors");
7564 auto dim = getMaskDimSizes()[0];
7565 if (dim != 0 && dim != 1)
7566 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7567 return success();
7568 }
7569
7570 // Verify that array attr size matches the rank of the vector result.
7571 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7572 return emitOpError(
7573 "must specify array attr of size equal vector result rank");
7574 // Verify that each array attr element is in bounds of corresponding vector
7575 // result dimension size.
7576 auto resultShape = resultType.getShape();
7577 auto resultScalableDims = resultType.getScalableDims();
7578 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7579 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7580 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7581 return emitOpError(
7582 "array attr of size out of bounds of vector result dimension size");
7583 if (resultScalableDims[index] && maskDimSize != 0 &&
7584 maskDimSize != resultShape[index])
7585 return emitOpError(
7586 "only supports 'none set' or 'all set' scalable dimensions");
7587 }
7588 // Verify that if one mask dim size is zero, they all should be zero (because
7589 // the mask region is a conjunction of each mask dimension interval).
7590 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7591 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7592 if (anyZeros && !allZeros)
7593 return emitOpError("expected all mask dim sizes to be zeros, "
7594 "as a result of conjunction with zero mask dim");
7595 return success();
7596}
7597
7598bool ConstantMaskOp::isAllOnesMask() {
7599 auto resultType = getVectorType();
7600 // Check the corner case of 0-D vectors first.
7601 if (resultType.getRank() == 0) {
7602 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7603 return getMaskDimSizes()[0] == 1;
7604 }
7605 for (const auto [resultSize, maskDimSize] :
7606 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7607 if (maskDimSize < resultSize)
7608 return false;
7609 }
7610 return true;
7611}
7612
7613OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7614 ArrayRef<int64_t> bounds = getMaskDimSizes();
7615 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7616
7617 auto createBoolSplat = [&](bool x) {
7618 return SplatElementsAttr::get(getVectorType(),
7620 };
7621
7622 // Check the corner case of 0-D vectors first.
7623 if (vectorSizes.empty()) {
7624 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7625 return createBoolSplat(bounds[0] == 1);
7626 }
7627 // Fold vector.constant_mask to splat if possible.
7628 if (bounds == vectorSizes)
7629 return createBoolSplat(true);
7630 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7631 return createBoolSplat(false);
7632 return OpFoldResult();
7633}
7634
7635//===----------------------------------------------------------------------===//
7636// CreateMaskOp
7637//===----------------------------------------------------------------------===//
7638
7639void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7640 VectorType type,
7641 ArrayRef<OpFoldResult> mixedOperands) {
7642 SmallVector<Value> operands =
7643 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7644 build(builder, result, type, operands);
7645}
7646
7647LogicalResult CreateMaskOp::verify() {
7648 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7649 // Verify that an operand was specified for each result vector each dimension.
7650 if (vectorType.getRank() == 0) {
7651 if (getNumOperands() != 1)
7652 return emitOpError(
7653 "must specify exactly one operand for 0-D create_mask");
7654 } else if (getNumOperands() !=
7655 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7656 return emitOpError(
7657 "must specify an operand for each result vector dimension");
7658 }
7659 return success();
7660}
7661
7662namespace {
7663
7664/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7665///
7666/// Ex 1:
7667/// %c2 = arith.constant 2 : index
7668/// %c3 = arith.constant 3 : index
7669/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7670/// Becomes:
7671/// vector.constant_mask [3, 2] : vector<4x3xi1>
7672///
7673/// Ex 2:
7674/// %c_neg_1 = arith.constant -1 : index
7675/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7676/// becomes:
7677/// vector.constant_mask [0] : vector<[8]xi1>
7678///
7679/// Ex 3:
7680/// %c8 = arith.constant 8 : index
7681/// %c16 = arith.constant 16 : index
7682/// %0 = vector.vscale
7683/// %1 = arith.muli %0, %c16 : index
7684/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7685/// becomes:
7686/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7687class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7688public:
7689 using Base::Base;
7690
7691 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7692 PatternRewriter &rewriter) const override {
7693 VectorType maskType = createMaskOp.getVectorType();
7694 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7695 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7696
7697 // Special case: Rank zero shape.
7698 constexpr std::array<int64_t, 1> rankZeroShape{1};
7699 constexpr std::array<bool, 1> rankZeroScalableDims{false};
7700 if (maskType.getRank() == 0) {
7701 maskTypeDimSizes = rankZeroShape;
7702 maskTypeDimScalableFlags = rankZeroScalableDims;
7703 }
7704
7705 // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7706 // collect the `constantDims` (for the ConstantMaskOp).
7707 SmallVector<int64_t, 4> constantDims;
7708 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7709 if (auto intSize = getConstantIntValue(dimSize)) {
7710 // Constant value.
7711 // If the mask dim is non-scalable this can be any value.
7712 // If the mask dim is scalable only zero (all-false) is supported.
7713 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7714 return failure();
7715 constantDims.push_back(*intSize);
7716 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7717 // Constant vscale multiple (e.g. 4 x vscale).
7718 // Must be all-true to fold to a ConstantMask.
7719 if (vscaleMultiplier < maskTypeDimSizes[i])
7720 return failure();
7721 constantDims.push_back(*vscaleMultiplier);
7722 } else {
7723 return failure();
7724 }
7725 }
7726
7727 // Clamp values to constant_mask bounds.
7728 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7729 value = std::clamp<int64_t>(value, 0, maskDimSize);
7730
7731 // If one of dim sizes is zero, set all dims to zero.
7732 if (llvm::is_contained(constantDims, 0))
7733 constantDims.assign(constantDims.size(), 0);
7734
7735 // Replace 'createMaskOp' with ConstantMaskOp.
7736 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7737 constantDims);
7738 return success();
7739 }
7740};
7741
7742} // namespace
7743
7744void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7745 MLIRContext *context) {
7746 results.add<CreateMaskFolder>(context);
7747}
7748
7749//===----------------------------------------------------------------------===//
7750// MaskOp
7751//===----------------------------------------------------------------------===//
7752
7753void MaskOp::build(
7754 OpBuilder &builder, OperationState &result, Value mask,
7755 Operation *maskableOp,
7756 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7757 assert(maskRegionBuilder &&
7758 "builder callback for 'maskRegion' must be present");
7759
7760 result.addOperands(mask);
7761 OpBuilder::InsertionGuard guard(builder);
7762 Region *maskRegion = result.addRegion();
7763 builder.createBlock(maskRegion);
7764 maskRegionBuilder(builder, maskableOp);
7765}
7766
7767void MaskOp::build(
7768 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7769 Value mask, Operation *maskableOp,
7770 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7771 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
7772 maskRegionBuilder);
7773}
7774
7775void MaskOp::build(
7776 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7777 Value mask, Value passthru, Operation *maskableOp,
7778 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7779 build(builder, result, mask, maskableOp, maskRegionBuilder);
7780 if (passthru)
7781 result.addOperands(passthru);
7782 result.addTypes(resultTypes);
7783}
7784
7785ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7786 // Create the op region.
7787 result.regions.reserve(1);
7788 Region &maskRegion = *result.addRegion();
7789
7790 auto &builder = parser.getBuilder();
7791
7792 // Parse all the operands.
7793 OpAsmParser::UnresolvedOperand mask;
7794 if (parser.parseOperand(mask))
7795 return failure();
7796
7797 // Optional passthru operand.
7798 OpAsmParser::UnresolvedOperand passthru;
7799 ParseResult parsePassthru = parser.parseOptionalComma();
7800 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7801 return failure();
7802
7803 // Parse op region.
7804 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7805 return failure();
7806
7807 MaskOp::ensureTerminator(maskRegion, builder, result.location);
7808
7809 // Parse the optional attribute list.
7810 if (parser.parseOptionalAttrDict(result.attributes))
7811 return failure();
7812
7813 // Parse all the types.
7814 Type maskType;
7815 if (parser.parseColonType(maskType))
7816 return failure();
7817
7818 SmallVector<Type> resultTypes;
7819 if (parser.parseOptionalArrowTypeList(resultTypes))
7820 return failure();
7821 result.types.append(resultTypes);
7822
7823 // Resolve operands.
7824 if (parser.resolveOperand(mask, maskType, result.operands))
7825 return failure();
7826
7827 if (parsePassthru.succeeded()) {
7828 if (resultTypes.empty())
7829 return parser.emitError(
7830 parser.getNameLoc(),
7831 "expects a result if passthru operand is provided");
7832
7833 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7834 return failure();
7835 }
7836
7837 return success();
7838}
7839
7840void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7841 p << " " << getMask();
7842 if (getPassthru())
7843 p << ", " << getPassthru();
7844
7845 // Print single masked operation and skip terminator.
7846 p << " { ";
7847 Block *singleBlock = &getMaskRegion().getBlocks().front();
7848 if (singleBlock && !singleBlock->getOperations().empty())
7849 p.printCustomOrGenericOp(&singleBlock->front());
7850 p << " }";
7851
7852 p.printOptionalAttrDict(getOperation()->getAttrs());
7853
7854 p << " : " << getMask().getType();
7855 if (getNumResults() > 0)
7856 p << " -> " << getResultTypes();
7857}
7858
7859void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7860 // 1. For an empty `vector.mask`, create a default terminator.
7861 if (region.empty() || region.front().empty()) {
7862 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7863 MaskOp>::ensureTerminator(region, builder, loc);
7864 return;
7865 }
7866
7867 // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7868 Block &block = region.front();
7869 if (isa<vector::YieldOp>(block.back()))
7870 return;
7871
7872 // 3. For a non-empty `vector.mask` without an explicit terminator:
7873
7874 // Create default terminator if the number of masked operations is not
7875 // one. This case will trigger a verification failure.
7876 if (block.getOperations().size() != 1) {
7877 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7878 MaskOp>::ensureTerminator(region, builder, loc);
7879 return;
7880 }
7881
7882 // Create a terminator that yields the results from the masked operation.
7883 OpBuilder opBuilder(builder.getContext());
7884 Operation *maskedOp = &block.front();
7885 opBuilder.setInsertionPointToEnd(&block);
7886 vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7887}
7888
7889LogicalResult MaskOp::verify() {
7890 // Structural checks.
7891 Block &block = getMaskRegion().getBlocks().front();
7892 if (block.getOperations().empty())
7893 return emitOpError("expects a terminator within the mask region");
7894
7895 unsigned numMaskRegionOps = block.getOperations().size();
7896 if (numMaskRegionOps > 2)
7897 return emitOpError("expects only one operation to mask");
7898
7899 // Terminator checks.
7900 auto terminator = dyn_cast<vector::YieldOp>(block.back());
7901 if (!terminator)
7902 return emitOpError("expects a terminator within the mask region");
7903
7904 if (terminator->getNumOperands() != getNumResults())
7905 return emitOpError(
7906 "expects number of results to match mask region yielded values");
7907
7908 // Empty vector.mask. Nothing else to check.
7909 if (numMaskRegionOps == 1)
7910 return success();
7911
7912 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7913 if (!maskableOp)
7914 return emitOpError("expects a MaskableOpInterface within the mask region");
7915
7916 // Result checks.
7917 if (maskableOp->getNumResults() != getNumResults())
7918 return emitOpError("expects number of results to match maskable operation "
7919 "number of results");
7920
7921 if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7922 return emitOpError("expects all the results from the MaskableOpInterface "
7923 "to match all the values returned by the terminator");
7924
7925 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7926 return emitOpError(
7927 "expects result type to match maskable operation result type");
7928
7929 if (llvm::count_if(maskableOp->getResultTypes(),
7930 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7931 return emitOpError("multiple vector results not supported");
7932
7933 // Mask checks.
7934 Type expectedMaskType = maskableOp.getExpectedMaskType();
7935 if (getMask().getType() != expectedMaskType)
7936 return emitOpError("expects a ")
7937 << expectedMaskType << " mask for the maskable operation";
7938
7939 // Passthru checks.
7940 Value passthru = getPassthru();
7941 if (passthru) {
7942 if (!maskableOp.supportsPassthru())
7943 return emitOpError(
7944 "doesn't expect a passthru argument for this maskable operation");
7945
7946 if (maskableOp->getNumResults() != 1)
7947 return emitOpError("expects result when passthru argument is provided");
7948
7949 if (passthru.getType() != maskableOp->getResultTypes()[0])
7950 return emitOpError("expects passthru type to match result type");
7951 }
7952
7953 return success();
7954}
7955
7956/// Folds empty `vector.mask` with no passthru operand and with or without
7957/// return values. For example:
7958///
7959/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7960/// vector<8xi1> -> vector<8xf32>
7961/// %1 = user_op %0 : vector<8xf32>
7962///
7963/// becomes:
7964///
7965/// %0 = user_op %a : vector<8xf32>
7966///
7967/// Empty `vector.mask` with passthru operand are handled by the canonicalizer
7968/// as it requires creating new operations.
7969
7970static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7971 SmallVectorImpl<OpFoldResult> &results) {
7972 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7973 return failure();
7974
7975 Block *block = maskOp.getMaskBlock();
7976 auto terminator = cast<vector::YieldOp>(block->front());
7977 if (terminator.getNumOperands() == 0)
7978 return failure();
7979
7980 // `vector.mask` has results, propagate the results.
7981 llvm::append_range(results, terminator.getOperands());
7982 return success();
7983}
7984
7985LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7986 SmallVectorImpl<OpFoldResult> &results) {
7987 if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7988 return success();
7989
7990 MaskFormat maskFormat = getMaskFormat(getMask());
7991 if (maskFormat != MaskFormat::AllTrue)
7992 return failure();
7993
7994 // Move maskable operation outside of the `vector.mask` region.
7995 // If there is no maskable op (empty body), the fold cannot proceed; the
7996 // canonicalizer handles this case instead.
7997 Operation *maskableOp = getMaskableOp();
7998 if (!maskableOp)
7999 return failure();
8000 maskableOp->dropAllUses();
8001 maskableOp->moveBefore(getOperation());
8002
8003 llvm::append_range(results, maskableOp->getResults());
8004 return success();
8005}
8006
8007/// Canonialize empty `vector.mask` operations that can't be handled in
8008/// `VectorMask::fold` as they require creating new operations.
8009///
8010/// Example 1: Empty `vector.mask` with passthru operand.
8011///
8012/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
8013/// vector<8xi1> -> vector<8xf32>
8014///
8015/// becomes:
8016///
8017/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
8018///
8019class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
8020 using Base::Base;
8021
8022 LogicalResult matchAndRewrite(MaskOp maskOp,
8023 PatternRewriter &rewriter) const override {
8024 if (!maskOp.isEmpty())
8025 return failure();
8026
8027 if (!maskOp.hasPassthru())
8028 return failure();
8029
8030 // arith.select with a vector condition requires the value types to be
8031 // vectors of the same shape. Since vector.mask always has a vector mask
8032 // type, bail out when any result type doesn't match the mask shape to
8033 // avoid creating invalid IR.
8034 VectorType maskType = maskOp.getMask().getType();
8035 for (Type resultType : maskOp.getResultTypes()) {
8036 auto vecResultType = dyn_cast<VectorType>(resultType);
8037 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
8038 return failure();
8039 }
8040
8041 Block *block = maskOp.getMaskBlock();
8042 auto terminator = cast<vector::YieldOp>(block->front());
8043 assert(terminator.getNumOperands() == 1 &&
8044 "expected one result when passthru is provided");
8045
8046 rewriter.replaceOpWithNewOp<arith::SelectOp>(
8047 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
8048 terminator.getOperand(0), maskOp.getPassthru());
8049
8050 return success();
8051 }
8052};
8053
8054void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
8055 MLIRContext *context) {
8056 results.add<CanonializeEmptyMaskOp>(context);
8057}
8058
8059// MaskingOpInterface definitions.
8060
8061/// Returns the operation masked by this 'vector.mask'.
8062Operation *MaskOp::getMaskableOp() {
8063 Block *block = getMaskBlock();
8064 if (block->getOperations().size() < 2)
8065 return nullptr;
8066
8067 return &block->front();
8068}
8069
8070/// Returns true if 'vector.mask' has a passthru value.
8071bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
8072
8073//===----------------------------------------------------------------------===//
8074// ScanOp
8075//===----------------------------------------------------------------------===//
8076
8077LogicalResult ScanOp::verify() {
8078 VectorType srcType = getSourceType();
8079 VectorType initialType = getInitialValueType();
8080 // Check reduction dimension < rank.
8081 int64_t srcRank = srcType.getRank();
8082 int64_t reductionDim = getReductionDim();
8083 if (reductionDim >= srcRank)
8084 return emitOpError("reduction dimension ")
8085 << reductionDim << " has to be less than " << srcRank;
8086
8087 // Check that rank(initial_value) = rank(src) - 1.
8088 int64_t initialValueRank = initialType.getRank();
8089 if (initialValueRank != srcRank - 1)
8090 return emitOpError("initial value rank ")
8091 << initialValueRank << " has to be equal to " << srcRank - 1;
8092
8093 // Check shapes of initial value and src.
8094 ArrayRef<int64_t> srcShape = srcType.getShape();
8095 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
8096 SmallVector<int64_t> expectedShape;
8097 for (int i = 0; i < srcRank; i++) {
8098 if (i != reductionDim)
8099 expectedShape.push_back(srcShape[i]);
8100 }
8101 if (!llvm::equal(initialValueShapes, expectedShape)) {
8102 return emitOpError("incompatible input/initial value shapes");
8103 }
8104
8105 // Verify supported reduction kind.
8106 Type eltType = getDestType().getElementType();
8107 if (!isSupportedCombiningKind(getKind(), eltType))
8108 return emitOpError("unsupported reduction type ")
8109 << eltType << " for kind '" << stringifyCombiningKind(getKind())
8110 << "'";
8111
8112 return success();
8113}
8114
8116 RewritePatternSet &patterns, PatternBenefit benefit) {
8117 patterns
8118 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
8119 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
8120 StridedSliceConstantMaskFolder, TransposeFolder>(
8121 patterns.getContext(), benefit);
8122}
8123
8124Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
8125 CombiningKind kind, Value v1, Value acc,
8126 arith::FastMathFlagsAttr fastmath,
8127 Value mask) {
8128 Type t1 = getElementTypeOrSelf(v1.getType());
8129 Type tAcc = getElementTypeOrSelf(acc.getType());
8130 Value result;
8131
8132 switch (kind) {
8133 case CombiningKind::ADD:
8134 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
8135 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
8136 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8137 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
8138 else
8139 llvm_unreachable("invalid value types for ADD reduction");
8140 break;
8141 case CombiningKind::AND:
8142 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8143 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
8144 break;
8145 case CombiningKind::MAXNUMF:
8146 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8147 "expected float values");
8148 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
8149 break;
8150 case CombiningKind::MAXIMUMF:
8151 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8152 "expected float values");
8153 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
8154 break;
8155 case CombiningKind::MINNUMF:
8156 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8157 "expected float values");
8158 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
8159 break;
8160 case CombiningKind::MINIMUMF:
8161 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8162 "expected float values");
8163 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
8164 break;
8165 case CombiningKind::MAXSI:
8166 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8167 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
8168 break;
8169 case CombiningKind::MINSI:
8170 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8171 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
8172 break;
8173 case CombiningKind::MAXUI:
8174 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8175 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
8176 break;
8177 case CombiningKind::MINUI:
8178 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8179 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
8180 break;
8181 case CombiningKind::MUL:
8182 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
8183 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
8184 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8185 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
8186 else
8187 llvm_unreachable("invalid value types for MUL reduction");
8188 break;
8189 case CombiningKind::OR:
8190 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8191 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
8192 break;
8193 case CombiningKind::XOR:
8194 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8195 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
8196 break;
8197 };
8198
8199 assert(result && "unknown CombiningKind");
8200 return selectPassthru(b, mask, result, acc);
8201}
8202
8203//===----------------------------------------------------------------------===//
8204// StepOp
8205//===----------------------------------------------------------------------===//
8206
8207void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
8208 SetIntRangeFn setResultRanges) {
8209 auto resultType = cast<VectorType>(getType());
8210 if (resultType.isScalable()) {
8211 return;
8212 }
8213 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
8214 APInt zero(bitwidth, 0);
8215 APInt high(bitwidth, resultType.getDimSize(0) - 1);
8216 ConstantIntRanges result = {zero, high, zero, high};
8217 setResultRanges(getResult(), result);
8218}
8219
8220namespace {
8221
8222/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
8223/// constant large enough such that the result is the same at all indices.
8224///
8225/// For example, rewrite the 'greater than' comparison below,
8226///
8227/// ```mlir
8228/// %cst = arith.constant dense<7> : vector<3xindex>
8229/// %stp = vector.step : vector<3xindex>
8230/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
8231/// ```
8232///
8233/// as,
8234///
8235/// ```mlir
8236/// %out = arith.constant dense<false> : vector<3xi1>.
8237/// ```
8238///
8239/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
8240/// is false at ALL indices we fold. If the constant was 1, then
8241/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
8242/// conservatively preferring the 'compact' vector.step representation.
8243///
8244/// Note: this folder only works for the case where the constant (`%cst` above)
8245/// is the second operand of the comparison. The arith.cmpi canonicalizer will
8246/// ensure that constants are always second (on the right).
8247struct StepCompareFolder : public OpRewritePattern<StepOp> {
8248 using Base::Base;
8249
8250 LogicalResult matchAndRewrite(StepOp stepOp,
8251 PatternRewriter &rewriter) const override {
8252 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
8253
8254 for (OpOperand &use : stepOp.getResult().getUses()) {
8255 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8256 if (!cmpiOp)
8257 continue;
8258
8259 // arith.cmpi canonicalizer makes constants final operands.
8260 const unsigned stepOperandNumber = use.getOperandNumber();
8261 if (stepOperandNumber != 0)
8262 continue;
8263
8264 // Check that operand 1 is a constant.
8265 unsigned constOperandNumber = 1;
8266 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8267 std::optional<int64_t> maybeConstValue =
8268 getConstantIntValue(otherOperand);
8269 if (!maybeConstValue.has_value())
8270 continue;
8271
8272 int64_t constValue = maybeConstValue.value();
8273 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8274
8275 auto maybeSplat = [&]() -> std::optional<bool> {
8276 // Handle ult (unsigned less than) and uge (unsigned greater equal).
8277 if ((pred == arith::CmpIPredicate::ult ||
8278 pred == arith::CmpIPredicate::uge) &&
8279 stepSize <= constValue)
8280 return pred == arith::CmpIPredicate::ult;
8281
8282 // Handle ule and ugt.
8283 if ((pred == arith::CmpIPredicate::ule ||
8284 pred == arith::CmpIPredicate::ugt) &&
8285 stepSize - 1 <= constValue) {
8286 return pred == arith::CmpIPredicate::ule;
8287 }
8288
8289 // Handle eq and ne.
8290 if ((pred == arith::CmpIPredicate::eq ||
8291 pred == arith::CmpIPredicate::ne) &&
8292 stepSize <= constValue)
8293 return pred == arith::CmpIPredicate::ne;
8294
8295 return std::nullopt;
8296 }();
8297
8298 if (!maybeSplat.has_value())
8299 continue;
8300
8301 rewriter.setInsertionPointAfter(cmpiOp);
8302
8303 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8304 if (!type)
8305 continue;
8306
8307 auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
8308 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8309 type, boolAttr);
8310
8311 rewriter.replaceOp(cmpiOp, splat);
8312 return success();
8313 }
8314
8315 return failure();
8316 }
8317};
8318} // namespace
8319
8320void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8321 MLIRContext *context) {
8322 results.add<StepCompareFolder>(context);
8323}
8324
8325//===----------------------------------------------------------------------===//
8326// Vector Masking Utilities
8327//===----------------------------------------------------------------------===//
8328
8329/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
8330/// as masked operation.
8331void mlir::vector::createMaskOpRegion(OpBuilder &builder,
8332 Operation *maskableOp) {
8333 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
8334 Block *insBlock = builder.getInsertionBlock();
8335 // Create a block and move the op to that block.
8336 insBlock->getOperations().splice(
8337 insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
8338 YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
8339}
8340
8341/// Creates a vector.mask operation around a maskable operation. Returns the
8342/// vector.mask operation if the mask provided is valid. Otherwise, returns
8343/// the maskable operation itself.
8344Operation *mlir::vector::maskOperation(OpBuilder &builder,
8345 Operation *maskableOp, Value mask,
8346 Value passthru) {
8347 if (!mask)
8348 return maskableOp;
8349 if (passthru)
8350 return MaskOp::create(builder, maskableOp->getLoc(),
8351 maskableOp->getResultTypes(), mask, passthru,
8352 maskableOp, createMaskOpRegion);
8353 return MaskOp::create(builder, maskableOp->getLoc(),
8354 maskableOp->getResultTypes(), mask, maskableOp,
8356}
8357
8358/// Creates a vector select operation that picks values from `newValue` or
8359/// `passthru` for each result vector lane based on `mask`. This utility is used
8360/// to propagate the pass-thru value of vector.mask or for cases where only the
8361/// pass-thru value propagation is needed. VP intrinsics do not support
8362/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
8363/// usually able to match op + select patterns and fold them into a native
8364/// target instructions.
8365Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
8366 Value newValue, Value passthru) {
8367 if (!mask)
8368 return newValue;
8369
8370 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
8371 mask, newValue, passthru);
8372}
8373
8374//===----------------------------------------------------------------------===//
8375// InterleaveOp
8376//===----------------------------------------------------------------------===//
8377
8378namespace {
8379
8380/// This folder works on the following round-trip identity:
8381/// interleave(deinterleave(x).even, deinterleave(x).odd) -> x
8382struct InterleaveDeinterleaveFolder : public OpRewritePattern<InterleaveOp> {
8383 using Base::Base;
8384
8385 LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
8386 PatternRewriter &rewriter) const override {
8387 auto lhsDefOp = interleaveOp.getLhs().getDefiningOp<DeinterleaveOp>();
8388 auto rhsDefOp = interleaveOp.getRhs().getDefiningOp<DeinterleaveOp>();
8389 if (!lhsDefOp || !rhsDefOp || lhsDefOp != rhsDefOp)
8390 return failure();
8391 for (auto [idx, operand] : llvm::enumerate(interleaveOp.getOperands())) {
8392 if (cast<OpResult>(operand).getResultNumber() != idx)
8393 return failure();
8394 }
8395 rewriter.replaceOp(interleaveOp, lhsDefOp.getSource());
8396 return success();
8397 }
8398};
8399} // namespace
8400
8401void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
8402 MLIRContext *context) {
8403 results.add<InterleaveDeinterleaveFolder>(context);
8404}
8405
8406std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
8407 return llvm::to_vector<4>(getResultVectorType().getShape());
8408}
8409
8410//===----------------------------------------------------------------------===//
8411// DeinterleaveOp
8412//===----------------------------------------------------------------------===//
8413
8414std::optional<SmallVector<int64_t, 4>> DeinterleaveOp::getShapeForUnroll() {
8415 return llvm::to_vector<4>(getResultVectorType().getShape());
8416}
8417
8418//===----------------------------------------------------------------------===//
8419// TableGen'd op method definitions
8420//===----------------------------------------------------------------------===//
8421
8422#define GET_ATTRDEF_CLASSES
8423#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8424
8425#define GET_OP_CLASSES
8426#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition VectorOps.cpp:74
static OpFoldResult foldShuffleIdentityMask(ShuffleOp op)
Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract whose source is poison.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static OpFoldResult foldShufflePoisonInputs(MLIRContext *context, Attribute v1Attr, Attribute v2Attr)
Fold shuffle poison, poison -> poison.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into a poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
Definition VectorOps.cpp:64
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr, Attribute v2Attr)
Fold a shuffle of constant 1-D inputs by evaluating the mask.
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op)
If a shuffle operand is poison, replace all mask indices that reference it with kPoisonIndex.
static LogicalResult foldSize1TransferPermutationMap(TransferOp op)
When the vector type is vector<1xT>, the permutation map is irrelevant: the single vector lane always...
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
#define mul(a, b)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
bool empty()
Definition Block.h:158
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
iterator begin()
Definition Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:274
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:859
void setOperand(unsigned idx, Value value)
Definition Operation.h:376
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:440
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isF32() const
Definition Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:114
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition VectorOps.h:64
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition VectorOps.h:72
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const