MLIR 23.0.0git
VectorOps.cpp
Go to the documentation of this file.
1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
15
28#include "mlir/IR/AffineExpr.h"
29#include "mlir/IR/AffineMap.h"
30#include "mlir/IR/Builders.h"
34#include "mlir/IR/IRMapping.h"
38#include "mlir/IR/ValueRange.h"
41#include "mlir/Support/LLVM.h"
43#include "llvm/ADT/ArrayRef.h"
44#include "llvm/ADT/Repeated.h"
45#include "llvm/ADT/STLExtras.h"
46#include "llvm/ADT/SmallVector.h"
47#include "llvm/ADT/SmallVectorExtras.h"
48#include "llvm/ADT/StringSet.h"
49#include "llvm/ADT/TypeSwitch.h"
50#include "llvm/Support/Casting.h"
51
52#include <cassert>
53#include <cstdint>
54#include <numeric>
55
56#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
57// Pull in all enum type and utility function definitions.
58#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
59
60using namespace mlir;
61using namespace mlir::vector;
62
// NOTE(review): the enumerator lines of this enum (original lines 65-67) are
// missing from this extract. The mask classifier that follows returns
// MaskFormat::AllTrue and MaskFormat::Unknown, so at least those enumerators
// exist -- confirm the complete list against upstream.
63/// Helper enum to classify mask value.
64enum class MaskFormat {
68};
69
70/// Helper method to classify a mask value. Currently, the method
71/// looks "under the hood" of a constant value with dense attributes
72/// and a constant mask operation (since the client may be called at
73/// various stages during progressive lowering).
///
/// Recognized producers of the mask value:
///  - `arith.constant` holding a DenseIntElementsAttr: the bits are scanned
///    and the scan bails as soon as set and cleared bits are mixed;
///  - `vector.constant_mask`: AllTrue when every mask dim size reaches its
///    vector dim size; the all-false check requires that no mask dim size
///    is positive;
///  - `vector.create_mask`: only operands produced by `arith.constant` are
///    inspected, looking for a non-positive size. An all-true create_mask is
///    expected to have been folded to a constant_mask already.
/// Every other producer yields MaskFormat::Unknown.
///
/// NOTE(review): original lines 74 (the signature), 87, 89, 91, 110 and 121
/// are missing from this extract; they presumably hold the signature and the
/// early `return MaskFormat::...` statements -- confirm against upstream.
75 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
76 // Inspect constant dense values. We count up for bits that
77 // are set, count down for bits that are cleared, and bail
78 // when a mix is detected.
79 if (auto denseElts = llvm::dyn_cast<DenseIntElementsAttr>(c.getValue())) {
80 int64_t val = 0;
81 for (bool b : denseElts.getValues<bool>())
82 if (b && val >= 0)
83 val++;
84 else if (!b && val <= 0)
85 val--;
86 else
88 if (val > 0)
90 if (val < 0)
92 }
93 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
94 // Inspect constant mask index. If the index reaches or exceeds the
95 // dimension size, all bits are set. If the index is zero
96 // or less, no bits are set.
97 ArrayRef<int64_t> masks = m.getMaskDimSizes();
98 auto shape = m.getType().getShape();
99 bool allTrue = true;
100 bool allFalse = true;
101 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
102 if (maskIdx < dimSize)
103 allTrue = false;
104 if (maskIdx > 0)
105 allFalse = false;
106 }
107 if (allTrue)
108 return MaskFormat::AllTrue;
109 if (allFalse)
111 } else if (auto m = mask.getDefiningOp<CreateMaskOp>()) {
112 // Finds all-false create_masks. An all-true create_mask requires all
113 // dims to be constants, so that'll be folded to a constant_mask, then
114 // detected in the constant_mask case.
115 auto maskOperands = m.getOperands();
116 for (Value operand : maskOperands) {
117 if (auto constantOp = operand.getDefiningOp<arith::ConstantOp>()) {
118 int64_t dimSize =
119 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
120 if (dimSize <= 0)
122 }
123 }
124 return MaskFormat::Unknown;
125 }
126 return MaskFormat::Unknown;
127}
128
129/// Default callback to build a region with a 'vector.yield' terminator with no
130/// arguments.
///
/// NOTE(review): the signature (original line 131) is missing from this
/// extract; the body reads `builder` and `loc` and creates a single
/// `vector.yield` with no operands at `loc`.
132 vector::YieldOp::create(builder, loc);
133}
134
135// Helper for verifying combining kinds in contractions and reductions.
136static bool isSupportedCombiningKind(CombiningKind combiningKind,
137 Type elementType) {
138 switch (combiningKind) {
139 case CombiningKind::ADD:
140 case CombiningKind::MUL:
141 return elementType.isIntOrIndexOrFloat();
142 case CombiningKind::MINUI:
143 case CombiningKind::MINSI:
144 case CombiningKind::MAXUI:
145 case CombiningKind::MAXSI:
146 case CombiningKind::AND:
147 case CombiningKind::OR:
148 case CombiningKind::XOR:
149 return elementType.isIntOrIndex();
150 case CombiningKind::MINNUMF:
151 case CombiningKind::MAXNUMF:
152 case CombiningKind::MINIMUMF:
153 case CombiningKind::MAXIMUMF:
154 return llvm::isa<FloatType>(elementType);
155 }
156 return false;
157}
158
159/// Returns the effective rank of the vector to read/write for Xfer Ops
160///
161/// When the element type of the shaped type is _a scalar_, this will simply
162/// return the rank of the vector ( the result for xfer_read or the value to
163/// store for xfer_write).
164///
165/// When the element type of the base shaped type is _a vector_, returns the
166/// difference between the original vector type and the element type of the
167/// shaped type.
168///
169/// EXAMPLE 1 (element type is _a scalar_):
170/// - shapedType = tensor<10x20xf32>, vectorType = vector<2x4xf32>
171/// - shapedType.getElementType() = f32 (rank 0)
172/// - vectorType.getRank() = 2
173/// - Result = 2 - 0 = 2
174///
175/// EXAMPLE 2 (element type is _a vector_):
176/// - shapedType = tensor<10xvector<20xf32>>, vectorType = vector<20xf32>
177/// - shapedType.getElementType() = vector<20xf32> (rank 1)
178/// - vectorType.getRank() = 1
179/// - Result = 1 - 1 = 0
180///
181/// This is used to determine the number of minor dimensions for identity maps
182/// in vector transfer Ops.
183static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType,
184 VectorType vectorType) {
185 unsigned elementVectorRank = 0;
186 VectorType elementVectorType =
187 llvm::dyn_cast<VectorType>(shapedType.getElementType());
188 if (elementVectorType)
189 elementVectorRank += elementVectorType.getRank();
190 return vectorType.getRank() - elementVectorRank;
191}
192
194 VectorType vectorType) {
// Builds the AffineMap for a transfer between `shapedType` and `vectorType`.
// A rank-0 shaped type paired with a vector<1xT> gets the constant map
// `() -> (0)`; otherwise the map is built from the shaped rank, the effective
// vector rank (getEffectiveVectorRankForXferOp) and the context.
// NOTE(review): original lines 193 (return type, name and first parameter)
// and 202 (the callee of the final return) are missing from this extract.
195 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
196 // TODO: replace once we have 0-d vectors.
197 if (shapedType.getRank() == 0 &&
198 vectorType.getShape() == ArrayRef<int64_t>{1})
199 return AffineMap::get(
200 /*numDims=*/0, /*numSymbols=*/0,
201 getAffineConstantExpr(0, shapedType.getContext()));
203 shapedType.getRank(),
204 getEffectiveVectorRankForXferOp(shapedType, vectorType),
205 shapedType.getContext());
206}
207
208/// Check if `write` is of a constant splat and the masked `read` is padded with
209/// the same splat value -- meaning it could be the same value as the initial
210/// constant splat.
211static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write,
212 vector::TransferReadOp read) {
213 auto readMask = read.getMask();
214 auto writeMask = write.getMask();
215 // Check if the masks are consistent. The splat value could be the same if the
216 // read is masked (and padded with the splat value), and the write is unmasked
217 // or has the same mask. Note this does not allow the case where the write is
218 // masked and the read is unmasked, as then the read could be of more elements
219 // than the write (which may not be the same value).
220 bool couldBeSameSplat = readMask && (!writeMask || writeMask == readMask);
221 if (!couldBeSameSplat)
222 return false;
223 // Check for constant splat (as the source of the write).
224 DenseElementsAttr splatAttr;
225 if (!matchPattern(write.getVector(),
226 m_Constant<DenseElementsAttr>(&splatAttr)) ||
227 !splatAttr.isSplat()) {
228 return false;
229 }
230 // The padding of the read and the constant splat value must be the same.
231 Attribute padAttr;
232 if (!matchPattern(read.getPadding(), m_Constant(&padAttr)))
233 return false;
234 return padAttr == splatAttr.getSplatValue<Attribute>();
235}
236
/// Read-after-write check between `defWrite` and `read`. Returns true only
/// when all of the following hold: `defWrite` has no out-of-bounds dimension,
/// both ops use the same indices, vector type and permutation map, and either
/// both are unmasked or the condition on original line 244 holds.
/// NOTE(review): line 244 is missing from this extract. It presumably calls
/// isSplatWriteConsistentWithMaskedRead, which has no other visible caller
/// in this extract -- confirm against upstream.
237bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
238 vector::TransferReadOp read) {
239 return !defWrite.hasOutOfBoundsDim() &&
240 defWrite.getIndices() == read.getIndices() &&
241 defWrite.getVectorType() == read.getVectorType() &&
242 defWrite.getPermutationMap() == read.getPermutationMap() &&
243 ((!defWrite.getMask() && !read.getMask()) ||
245}
246
247bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
248 vector::TransferWriteOp priorWrite) {
249 return priorWrite.getIndices() == write.getIndices() &&
250 priorWrite.getMask() == write.getMask() &&
251 priorWrite.getVectorType() == write.getVectorType() &&
252 priorWrite.getPermutationMap() == write.getPermutationMap();
253}
254
256 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB,
257 bool testDynamicValueUsingBounds) {
// Returns true when the indices of the two transfers can be shown to address
// disjoint data, and false when that cannot be proven. Transfers with
// different vector types are never considered disjoint. For index position i:
//  - i < getLeadingShapedRank(): indices that provably differ (as constants,
//    or through the bounds-based checks when `testDynamicValueUsingBounds`)
//    mean disjoint slices;
//  - otherwise the vector spans getDimSize(i - rankOffset) elements along
//    that dimension, so indices at least that far apart mean disjoint
//    intervals.
// NOTE(review): original lines 255 (return type and name), 280, 285, 303 and
// 308 (the analysis calls producing delta/testEqual/computeDelta) are missing
// from this extract.
// NOTE(review): the loop bound comes from transferA's indices only; transferB
// is assumed to have at least as many indices -- not checked here.
// NOTE(review): for scalable dims getDimSize presumably returns the base
// (vscale = 1) size, which would under-estimate the accessed interval --
// confirm scalable vectors are handled or excluded.
258 // For simplicity only look at transfer of same type.
259 if (transferA.getVectorType() != transferB.getVectorType())
260 return false;
261 unsigned rankOffset = transferA.getLeadingShapedRank();
262 for (unsigned i = 0, e = transferA.getIndices().size(); i < e; i++) {
263 Value indexA = transferA.getIndices()[i];
264 Value indexB = transferB.getIndices()[i];
265 std::optional<int64_t> cstIndexA = getConstantIntValue(indexA);
266 std::optional<int64_t> cstIndexB = getConstantIntValue(indexB);
267
268 if (i < rankOffset) {
269 // For leading dimensions, if we can prove that index are different we
270 // know we are accessing disjoint slices.
271 if (cstIndexA.has_value() && cstIndexB.has_value()) {
272 if (*cstIndexA != *cstIndexB)
273 return true;
274 continue;
275 }
276 if (testDynamicValueUsingBounds) {
277 // First try to see if we can fully compose and simplify the affine
278 // expression as a fast track.
// NOTE(review): this branch stores the delta as FailureOr<uint64_t>,
// unlike the FailureOr<int64_t> used for trailing dims below. Only
// `!= 0` is tested, so the sign is irrelevant here, but the types look
// inconsistent.
279 FailureOr<uint64_t> delta =
281 if (succeeded(delta) && *delta != 0)
282 return true;
283
284 FailureOr<bool> testEqual =
286 if (succeeded(testEqual) && !testEqual.value())
287 return true;
288 }
289 } else {
290 // For this dimension, we slice a part of the memref we need to make sure
291 // the intervals accessed don't overlap.
292 int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset);
293 if (cstIndexA.has_value() && cstIndexB.has_value()) {
294 int64_t distance = std::abs(*cstIndexA - *cstIndexB);
295 if (distance >= vectorDim)
296 return true;
297 continue;
298 }
299 if (testDynamicValueUsingBounds) {
300 // First try to see if we can fully compose and simplify the affine
301 // expression as a fast track.
302 FailureOr<int64_t> delta =
304 if (succeeded(delta) && std::abs(*delta) >= vectorDim)
305 return true;
306
307 FailureOr<int64_t> computeDelta =
309 if (succeeded(computeDelta)) {
310 if (std::abs(computeDelta.value()) >= vectorDim)
311 return true;
312 }
313 }
314 }
315 }
316 return false;
317}
318
319bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
320 VectorTransferOpInterface transferB,
321 bool testDynamicValueUsingBounds) {
322 if (transferA.getBase() != transferB.getBase())
323 return false;
324 return isDisjointTransferIndices(transferA, transferB,
325 testDynamicValueUsingBounds);
326}
327
328// Helper to iterate over n-D vector slice elements. Calculate the next
329// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
330// Modifies the `position` in place. Returns a failure when `position` becomes
331// the end position.
// Dimensions are advanced from the innermost (last) to the outermost, with
// dimension d ranging over [offsets[d], offsets[d] + shape[d]). On failure,
// every entry of `position` has wrapped back to its offset.
// NOTE(review): original line 333 (presumably the `shape` parameter, which
// the body reads) is missing from this extract.
332static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
334 ArrayRef<int64_t> offsets) {
335 for (auto [posInDim, dimSize, offsetInDim] :
336 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
337 ++posInDim;
338 if (posInDim < dimSize + offsetInDim)
339 return success();
340
341 // Carry the overflow to the next loop iteration.
342 posInDim = offsetInDim;
343 }
344
345 return failure();
346}
347
348/// Returns the integer numbers in `values`. `values` are expected to be
349/// constant operations.
/// Each value must be produced by an index-typed `arith.constant`
/// (arith::ConstantIndexOp). This is only enforced by an assert, so builds
/// without assertions do not diagnose a non-constant value.
/// NOTE(review): original lines 350-351 (the signature and the declaration of
/// `ints`) are missing from this extract.
352 llvm::transform(values, std::back_inserter(ints), [](Value value) {
353 auto constOp = value.getDefiningOp<arith::ConstantIndexOp>();
354 assert(constOp && "Unexpected non-constant index");
355 return constOp.value();
356 });
357 return ints;
358}
359
360/// Returns the integer numbers in `foldResults`. `foldResults` are expected to
361/// be constant operations.
/// Concretely, every entry must hold an Attribute that is an IntegerAttr (not
/// a Value); this is only enforced by the assert and the cast.
/// NOTE(review): original lines 362-363 (the signature and the declaration of
/// `ints`) are missing from this extract.
364 llvm::transform(
365 foldResults, std::back_inserter(ints), [](OpFoldResult foldResult) {
366 assert(isa<Attribute>(foldResult) && "Unexpected non-constant index");
367 return cast<IntegerAttr>(cast<Attribute>(foldResult)).getInt();
368 });
369 return ints;
370}
371
372/// Convert `foldResults` into Values. Integer attributes are converted to
373/// constant op.
/// Value entries are returned unchanged. Attribute entries must be
/// IntegerAttr; they are materialized at `loc` with `builder` from their
/// getInt() value.
/// NOTE(review): original lines 374 (return type, name and the builder/loc
/// parameters) and 380 (the op-creation call) are missing from this extract.
375 ArrayRef<OpFoldResult> foldResults) {
376 SmallVector<Value> values;
377 llvm::transform(foldResults, std::back_inserter(values),
378 [&](OpFoldResult foldResult) {
379 if (auto attr = dyn_cast<Attribute>(foldResult))
381 builder, loc, cast<IntegerAttr>(attr).getInt())
382 .getResult();
383
384 return cast<Value>(foldResult);
385 });
386 return values;
387}
388
389std::optional<int64_t> vector::getConstantVscaleMultiplier(Value value) {
390 if (value.getDefiningOp<vector::VectorScaleOp>())
391 return 1;
392 auto mul = value.getDefiningOp<arith::MulIOp>();
393 if (!mul)
394 return {};
395 auto lhs = mul.getLhs();
396 auto rhs = mul.getRhs();
397 if (lhs.getDefiningOp<vector::VectorScaleOp>())
398 return getConstantIntValue(rhs);
399 if (rhs.getDefiningOp<vector::VectorScaleOp>())
400 return getConstantIntValue(lhs);
401 return {};
402}
403
404/// Converts numeric attributes to the expected type. Supports
405/// integer-to-integer and float-to-integer conversions. Returns the original
406/// attribute if no conversion is needed or supported.
407static Attribute convertNumericAttr(Attribute attr, Type expectedType) {
408 // Integer-to-integer conversion
409 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
410 if (auto intType = dyn_cast<IntegerType>(expectedType)) {
411 if (intAttr.getType() != expectedType)
412 return IntegerAttr::get(expectedType, intAttr.getInt());
413 }
414 return attr;
415 }
416
417 // Float-to-integer bitcast (preserves bit representation)
418 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
419 auto intType = dyn_cast<IntegerType>(expectedType);
420 if (!intType)
421 return attr;
422
423 APFloat floatVal = floatAttr.getValue();
424 APInt intVal = floatVal.bitcastToAPInt();
425 return IntegerAttr::get(expectedType, intVal);
426 }
427
428 return attr;
429}
430
431//===----------------------------------------------------------------------===//
432// CombiningKindAttr
433//===----------------------------------------------------------------------===//
434
435namespace mlir {
436namespace vector {
437namespace detail {
// Storage class keyed by a uint64_t bitmask value (this section implements
// CombiningKindAttr). Equality compares the stored `value` with the key, and
// construction placement-news the storage into memory taken from
// `allocator`.
// NOTE(review): original lines 438 (struct header), 441 (presumably the
// constructor), 445, 448 and 451 (presumably the `value` member) are missing
// from this extract. The struct name BitmaskEnumStorage is only visible
// through the allocate<> call.
439 using KeyTy = uint64_t;
440
442
443 bool operator==(const KeyTy &key) const { return value == key; }
444
446 const KeyTy &key) {
447 return new (allocator.allocate<BitmaskEnumStorage>())
449 }
450
452};
453} // namespace detail
454} // namespace vector
455} // namespace mlir
456
457//===----------------------------------------------------------------------===//
458// VectorDialect
459//===----------------------------------------------------------------------===//
460
461namespace {
462/// This class defines the interface for handling inlining with vector dialect
463/// operations.
464struct VectorInlinerInterface : public DialectInlinerInterface {
465 using DialectInlinerInterface::DialectInlinerInterface;
466
467 /// All vector dialect ops can be inlined.
468 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
469 return true;
470 }
471};
472} // namespace
473
474void VectorDialect::initialize() {
475 addAttributes<
476#define GET_ATTRDEF_LIST
477#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
478 >();
479
480 addOperations<
481#define GET_OP_LIST
482#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
483 >();
484
485 addInterfaces<VectorInlinerInterface>();
486
487 declarePromisedInterfaces<memref::IndexedAccessOpInterface, LoadOp, StoreOp,
488 MaskedLoadOp, MaskedStoreOp, ExpandLoadOp,
489 CompressStoreOp>();
490 declarePromisedInterfaces<bufferization::BufferizableOpInterface,
491 TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
492 YieldOp>();
493 declarePromisedInterfaces<SubsetOpInterface, TransferReadOp,
494 TransferWriteOp>();
495 declarePromisedInterface<SubsetExtractionOpInterface, TransferReadOp>();
496 declarePromisedInterface<SubsetInsertionOpInterface, TransferWriteOp>();
497 declarePromisedInterface<ConvertToLLVMPatternInterface, VectorDialect>();
498}
499
500/// Materialize a single constant operation from a given attribute value with
501/// the desired resultant type.
502Operation *VectorDialect::materializeConstant(OpBuilder &builder,
503 Attribute value, Type type,
504 Location loc) {
505 if (matchPattern(value, ub::m_Poison()))
506 return value.getDialect().materializeConstant(builder, value, type, loc);
507
508 return arith::ConstantOp::materialize(builder, value, type, loc);
509}
510
// Returns the 64-bit signless integer type built by `builder`.
// NOTE(review): the signature (original line 511) is missing from this
// extract.
512 return builder.getIntegerType(64);
513}
514
516 ArrayRef<int64_t> values) {
// Wraps `values` into an ArrayAttr of i64 integer attributes
// (Builder::getI64ArrayAttr).
// NOTE(review): original line 515 (return type, name and the builder
// parameter) is missing from this extract.
517 return builder.getI64ArrayAttr(values);
518}
519
520//===----------------------------------------------------------------------===//
521// MultiDimReductionOp
522//===----------------------------------------------------------------------===//
523
524void vector::MultiDimReductionOp::build(OpBuilder &builder,
525 OperationState &result, Value source,
526 Value acc, ArrayRef<bool> reductionMask,
527 CombiningKind kind) {
528 SmallVector<int64_t> reductionDims;
529 for (const auto &en : llvm::enumerate(reductionMask))
530 if (en.value())
531 reductionDims.push_back(en.index());
532 build(builder, result, kind, source, acc, reductionDims);
533}
534
535OpFoldResult MultiDimReductionOp::fold(FoldAdaptor adaptor) {
536 // No reduction dims: this is a noop regardless of rank.
537 if (getReductionDims().empty())
538 return getSource();
539 return {};
540}
541
542std::optional<SmallVector<int64_t, 4>>
543MultiDimReductionOp::getShapeForUnroll() {
544 return llvm::to_vector<4>(getSourceVectorType().getShape());
545}
546
547LogicalResult MultiDimReductionOp::verify() {
548 SmallVector<int64_t> targetShape;
549 SmallVector<bool> scalableDims;
550 Type inferredReturnType;
551 auto sourceScalableDims = getSourceVectorType().getScalableDims();
552 for (auto [dimIdx, dimSize] :
553 llvm::enumerate(getSourceVectorType().getShape()))
554 if (!llvm::any_of(getReductionDims(),
555 [dimIdx = dimIdx](int64_t reductionDimIdx) {
556 return reductionDimIdx == static_cast<int64_t>(dimIdx);
557 })) {
558 targetShape.push_back(dimSize);
559 scalableDims.push_back(sourceScalableDims[dimIdx]);
560 }
561 // TODO: update to also allow 0-d vectors when available.
562 if (targetShape.empty())
563 inferredReturnType = getSourceVectorType().getElementType();
564 else
565 inferredReturnType = VectorType::get(
566 targetShape, getSourceVectorType().getElementType(), scalableDims);
567 if (getType() != inferredReturnType)
568 return emitOpError() << "destination type " << getType()
569 << " is incompatible with source type "
570 << getSourceVectorType();
571
572 return success();
573}
574
575/// Returns the mask type expected by this operation.
576Type MultiDimReductionOp::getExpectedMaskType() {
577 auto vecType = getSourceVectorType();
578 return VectorType::get(vecType.getShape(),
579 IntegerType::get(vecType.getContext(), /*width=*/1),
580 vecType.getScalableDims());
581}
582
583namespace {
584// Only unit dimensions that are being reduced are folded. If the dimension is
585// unit, but not reduced, it is not folded, thereby keeping the output type the
586// same. If not all dimensions which are reduced are of unit dimension, this
587// transformation does nothing. This is just a generalization of
588// ElideSingleElementReduction for ReduceOp.
589struct ElideUnitDimsInMultiDimReduction
590 : public OpRewritePattern<MultiDimReductionOp> {
591 using Base::Base;
592
593 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
594 PatternRewriter &rewriter) const override {
595 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
596 for (const auto &dim : enumerate(shape)) {
597 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
598 return failure();
599 }
600
601 // Vector mask setup.
602 OpBuilder::InsertionGuard guard(rewriter);
603 Operation *rootOp;
604 Value mask;
605 if (reductionOp.isMasked()) {
606 rewriter.setInsertionPoint(reductionOp.getMaskingOp());
607 rootOp = reductionOp.getMaskingOp();
608 mask = reductionOp.getMaskingOp().getMask();
609 } else {
610 rootOp = reductionOp;
611 }
612
613 Location loc = reductionOp.getLoc();
614 Value acc = reductionOp.getAcc();
615 Value cast;
616 if (auto dstVecType = dyn_cast<VectorType>(reductionOp.getDestType())) {
617 if (mask) {
618 VectorType newMaskType =
619 VectorType::get(dstVecType.getShape(), rewriter.getI1Type(),
620 dstVecType.getScalableDims());
621 mask = vector::ShapeCastOp::create(rewriter, loc, newMaskType, mask);
622 }
623 cast = vector::ShapeCastOp::create(
624 rewriter, loc, reductionOp.getDestType(), reductionOp.getSource());
625 } else {
626 // This means we are reducing all the dimensions, and all reduction
627 // dimensions are of size 1. So a simple extraction would do.
628 if (mask)
629 mask = vector::ExtractOp::create(rewriter, loc, mask);
630 cast = vector::ExtractOp::create(rewriter, loc, reductionOp.getSource());
631 }
632
633 Value result =
634 vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), acc,
635 cast, /*fastmath=*/nullptr, mask);
636 rewriter.replaceOp(rootOp, result);
637 return success();
638 }
639};
640} // namespace
641
642void MultiDimReductionOp::getCanonicalizationPatterns(
643 RewritePatternSet &results, MLIRContext *context) {
644 results.add<ElideUnitDimsInMultiDimReduction>(context);
645}
646
647//===----------------------------------------------------------------------===//
648// ReductionOp
649//===----------------------------------------------------------------------===//
650
651void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
652 CombiningKind kind, Value vector,
653 arith::FastMathFlags fastMathFlags) {
654 build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags);
655}
656
657void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
658 CombiningKind kind, Value vector, Value acc,
659 arith::FastMathFlags fastMathFlags) {
660 build(builder, result,
661 llvm::cast<VectorType>(vector.getType()).getElementType(), kind, vector,
662 acc, fastMathFlags);
663}
664
665LogicalResult ReductionOp::verify() {
666 // Verify for 0-D and 1-D vector.
667 int64_t rank = getSourceVectorType().getRank();
668 if (rank > 1)
669 return emitOpError("unsupported reduction rank: ") << rank;
670
671 // Verify supported reduction kind.
672 Type eltType = getDest().getType();
673 if (!isSupportedCombiningKind(getKind(), eltType))
674 return emitOpError("unsupported reduction type '")
675 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
676 << "'";
677
678 return success();
679}
680
681// MaskableOpInterface methods.
682
683/// Returns the mask type expected by this operation.
684Type ReductionOp::getExpectedMaskType() {
685 auto vecType = getSourceVectorType();
686 return VectorType::get(vecType.getShape(),
687 IntegerType::get(vecType.getContext(), /*width=*/1),
688 vecType.getScalableDims());
689}
690
692 OpBuilder &builder, Location loc,
693 Value vector) {
// Maps the arith::AtomicRMWKind `op` onto the matching vector.reduction
// CombiningKind and creates a reduction of `vector` without an accumulator.
// Float and integer add/mul share ADD/MUL; the min/max and bitwise kinds map
// one-to-one. Any other kind emits an optional error at `loc` and returns
// nullptr.
// NOTE(review): original line 691 (return type, name and presumably the `op`
// parameter) is missing from this extract.
// NOTE(review): the reductions are created at `vector.getLoc()`, so `loc` is
// only used on the error path -- confirm this is intended.
694 switch (op) {
695 case arith::AtomicRMWKind::addf:
696 case arith::AtomicRMWKind::addi:
697 return vector::ReductionOp::create(builder, vector.getLoc(),
698 CombiningKind::ADD, vector);
699 case arith::AtomicRMWKind::mulf:
700 case arith::AtomicRMWKind::muli:
701 return vector::ReductionOp::create(builder, vector.getLoc(),
702 CombiningKind::MUL, vector);
703 case arith::AtomicRMWKind::minimumf:
704 return vector::ReductionOp::create(builder, vector.getLoc(),
705 CombiningKind::MINIMUMF, vector);
706 case arith::AtomicRMWKind::mins:
707 return vector::ReductionOp::create(builder, vector.getLoc(),
708 CombiningKind::MINSI, vector);
709 case arith::AtomicRMWKind::minu:
710 return vector::ReductionOp::create(builder, vector.getLoc(),
711 CombiningKind::MINUI, vector);
712 case arith::AtomicRMWKind::maximumf:
713 return vector::ReductionOp::create(builder, vector.getLoc(),
714 CombiningKind::MAXIMUMF, vector);
715 case arith::AtomicRMWKind::maxs:
716 return vector::ReductionOp::create(builder, vector.getLoc(),
717 CombiningKind::MAXSI, vector);
718 case arith::AtomicRMWKind::maxu:
719 return vector::ReductionOp::create(builder, vector.getLoc(),
720 CombiningKind::MAXUI, vector);
721 case arith::AtomicRMWKind::andi:
722 return vector::ReductionOp::create(builder, vector.getLoc(),
723 CombiningKind::AND, vector);
724 case arith::AtomicRMWKind::ori:
725 return vector::ReductionOp::create(builder, vector.getLoc(),
726 CombiningKind::OR, vector);
727 case arith::AtomicRMWKind::minnumf:
728 return vector::ReductionOp::create(builder, vector.getLoc(),
729 CombiningKind::MINNUMF, vector);
730 case arith::AtomicRMWKind::maxnumf:
731 return vector::ReductionOp::create(builder, vector.getLoc(),
732 CombiningKind::MAXNUMF, vector);
733 case arith::AtomicRMWKind::xori:
734 return vector::ReductionOp::create(builder, vector.getLoc(),
735 CombiningKind::XOR, vector);
736 default:
737 (void)emitOptionalError(loc, "Reduction operation type not supported");
738 break;
739 }
740 return nullptr;
741}
742
743std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
744 return llvm::to_vector<4>(getSourceVectorType().getShape());
745}
746
747namespace {
748struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
749 using Base::Base;
750
751 LogicalResult matchAndRewrite(ReductionOp reductionOp,
752 PatternRewriter &rewriter) const override {
753 // Vector mask setup.
754 OpBuilder::InsertionGuard guard(rewriter);
755 auto maskableOp =
756 cast<vector::MaskableOpInterface>(reductionOp.getOperation());
757 Operation *rootOp;
758 Value mask;
759 if (maskableOp.isMasked()) {
760 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
761 rootOp = maskableOp.getMaskingOp();
762 mask = maskableOp.getMaskingOp().getMask();
763 } else {
764 rootOp = reductionOp;
765 }
766
767 auto vectorType = reductionOp.getSourceVectorType();
768 if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
769 return failure();
770
771 Location loc = reductionOp.getLoc();
772 if (mask)
773 mask = ExtractOp::create(rewriter, loc, mask);
774 Value result = ExtractOp::create(rewriter, loc, reductionOp.getVector());
775
776 if (Value acc = reductionOp.getAcc())
777 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
778 result, acc,
779 reductionOp.getFastmathAttr(), mask);
780
781 rewriter.replaceOp(rootOp, result);
782 return success();
783 }
784};
785} // namespace
786
787void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
788 MLIRContext *context) {
789 results.add<ElideSingleElementReduction>(context);
790}
791
792//===----------------------------------------------------------------------===//
793// ContractionOp
794//===----------------------------------------------------------------------===//
795
// Builds a contraction from per-operand index expressions. The operands are
// lhs, rhs and acc, and the result type is acc's type. Indexing maps are
// inferred from `indexingExprs` (AffineMap::inferFromExprList), and
// `iteratorTypes` are wrapped into IteratorTypeAttr values.
// NOTE(review): original line 797 (presumably the lhs, rhs and acc
// parameters) is missing from this extract.
// NOTE(review): unlike the ArrayAttr overload below, this builder sets
// neither `kind` nor `fastmath`; presumably ODS default values cover them --
// confirm.
796void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
798 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
799 ArrayRef<IteratorType> iteratorTypes) {
800 result.addOperands({lhs, rhs, acc});
801 result.addTypes(acc.getType());
802 result.addAttribute(
803 getIndexingMapsAttrName(result.name),
804 builder.getAffineMapArrayAttr(
805 AffineMap::inferFromExprList(indexingExprs, builder.getContext())));
806 result.addAttribute(
807 getIteratorTypesAttrName(result.name),
808 builder.getArrayAttr(llvm::map_to_vector(
809 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
810 return IteratorTypeAttr::get(builder.getContext(), t);
811 })));
812}
813
// Builds a contraction with ContractionOp::getDefaultKind() as the combining
// kind by forwarding to the kind-taking builder.
// NOTE(review): original line 815 (presumably the lhs, rhs and acc
// parameters) is missing from this extract. No fastmath argument is passed,
// so the callee presumably declares a default for it -- confirm.
814void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
816 ArrayAttr indexingMaps,
817 ArrayAttr iteratorTypes) {
818 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
819 ContractionOp::getDefaultKind());
820}
821
// Builds a contraction with the given operands and attributes. The result
// type is acc's type. The `fastmath` attribute is only attached when
// `fastMathFlags` is not `none`, so the default does not appear in the IR.
// NOTE(review): original line 823 (presumably the lhs, rhs and acc
// parameters) is missing from this extract.
822void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
824 ArrayAttr indexingMaps,
825 ArrayAttr iteratorTypes, CombiningKind kind,
826 arith::FastMathFlags fastMathFlags) {
827 result.addOperands({lhs, rhs, acc});
828 result.addTypes(acc.getType());
829 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
830 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
831 result.addAttribute(getKindAttrName(result.name),
832 CombiningKindAttr::get(builder.getContext(), kind));
833 if (fastMathFlags != arith::FastMathFlags::none)
834 result.addAttribute(
835 getFastmathAttrName(result.name),
836 arith::FastMathFlagsAttr::get(builder.getContext(), fastMathFlags));
837}
838
// Custom parser for vector.contract. Visible syntax:
//   {trait-attrs} %lhs, %rhs, %acc (, %mask)* attr-dict
//     : lhs-type, rhs-type into result-type
// - %acc is resolved with the result type, which is also the op's result
//   type.
// - The leading dictionary is merged into the op attributes. Its
//   iterator_types entry must be an array of strings (the legacy textual
//   form) and is rewritten into IteratorTypeAttr values.
// - When no kind attribute is given, ContractionOp::getDefaultKind() is used.
// - Either zero or exactly two trailing mask operands are accepted. They are
//   typed as i1 vectors shaped like lhs and rhs respectively.
// NOTE(review): original lines 840-844 (presumably the declarations of
// lhsInfo, rhsInfo, accInfo, masksInfo and types) are missing from this
// extract.
// NOTE(review): types[0] and types[1] are read without checking that the
// colon type list had at least two entries, and the mask path casts them to
// VectorType unchecked. Malformed input could index out of bounds or trip a
// cast assertion.
839ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
845 Type resultType;
846 auto loc = parser.getCurrentLocation();
847 DictionaryAttr dictAttr;
848 // TODO: Unify linalg op attribute parsing.
849 if (parser.parseAttribute(dictAttr) || parser.parseOperand(lhsInfo) ||
850 parser.parseComma() || parser.parseOperand(rhsInfo) ||
851 parser.parseComma() || parser.parseOperand(accInfo) ||
852 parser.parseTrailingOperandList(masksInfo) ||
853 parser.parseOptionalAttrDict(result.attributes) ||
854 parser.parseColonTypeList(types) ||
855 parser.parseKeywordType("into", resultType) ||
856 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
857 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
858 parser.resolveOperand(accInfo, resultType, result.operands) ||
859 parser.addTypeToList(resultType, result.types))
860 return failure();
861 result.attributes.append(dictAttr.getValue().begin(),
862 dictAttr.getValue().end());
863
864 // Convert an array of strings into an array of IteratorType enums. This is needed,
865 // because tests still use the old format when 'iterator_types' attribute is
866 // represented as an array of strings.
867 // TODO: Remove this conversion once tests are fixed.
868 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
869 result.attributes.get(getIteratorTypesAttrName(result.name)));
870 if (!iteratorTypes) {
871 return parser.emitError(loc)
872 << "expected " << getIteratorTypesAttrName(result.name)
873 << " array attribute";
874 }
875
876 SmallVector<Attribute> iteratorTypeAttrs;
877
// NOTE(review): getAsValueRange<StringAttr> assumes every element is a
// StringAttr; an already-typed iterator-type attribute element would
// presumably trip a cast assertion instead of producing a parse error.
878 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
879 auto maybeIteratorType = symbolizeIteratorType(s);
880 if (!maybeIteratorType.has_value())
881 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
882
883 iteratorTypeAttrs.push_back(
884 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
885 }
886 result.attributes.set(getIteratorTypesAttrName(result.name),
887 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
888
889 if (!result.attributes.get(getKindAttrName(result.name))) {
890 result.addAttribute(
891 getKindAttrName(result.name),
892 CombiningKindAttr::get(result.getContext(),
893 ContractionOp::getDefaultKind()));
894 }
895 if (masksInfo.empty())
896 return success();
897 if (masksInfo.size() != 2)
898 return parser.emitError(parser.getNameLoc(),
899 "expected zero or exactly 2 vector mask operands");
900 auto lhsType = llvm::cast<VectorType>(types[0]);
901 auto rhsType = llvm::cast<VectorType>(types[1]);
902 auto maskElementType = parser.getBuilder().getI1Type();
903 std::array<VectorType, 2> maskTypes = {
904 VectorType::Builder(lhsType).setElementType(maskElementType),
905 VectorType::Builder(rhsType).setElementType(maskElementType)};
906 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
907 return failure();
908 return success();
909}
910
// Custom printer for vector.contract. The trait attributes (named by
// getTraitAttrNames()) are printed as a leading dictionary: iterator_types
// are turned back into strings, and fastmath is dropped when it is `none`.
// The printer then emits lhs, rhs and acc, the remaining attributes, and the
// `: lhs-type, rhs-type into result-type` suffix.
// NOTE(review): original line 916 (presumably the declaration of `attrs`) is
// missing from this extract.
// NOTE(review): only lhs, rhs and acc are printed, although the parser
// accepts two trailing mask operands -- confirm that masked contractions
// round-trip.
911void ContractionOp::print(OpAsmPrinter &p) {
912 // TODO: Unify printing code with linalg ops.
913 auto attrNames = getTraitAttrNames();
914 llvm::StringSet<> traitAttrsSet;
915 traitAttrsSet.insert_range(attrNames);
917 for (auto attr : (*this)->getAttrs()) {
918 if (attr.getName() == getIteratorTypesAttrName()) {
919 auto iteratorTypes =
920 llvm::cast<ArrayAttr>(attr.getValue())
921 .getAsValueRange<IteratorTypeAttr, IteratorType>();
922 // Convert IteratorType enums into the string representation. This is
923 // needed, because tests still use the old format when 'iterator_types'
924 // attribute is represented as an array of strings.
925 // TODO: Remove this conversion once tests are fixed.
926 SmallVector<Attribute> iteratorTypeNames =
927 llvm::map_to_vector(iteratorTypes, [&](IteratorType t) -> Attribute {
928 return StringAttr::get(getContext(), stringifyIteratorType(t));
929 });
930
931 attrs.emplace_back(getIteratorTypesAttrName(),
932 ArrayAttr::get(getContext(), iteratorTypeNames));
933 } else if (traitAttrsSet.count(attr.getName().strref()) > 0) {
934 // Omit fastmath when it equals the default (none) to keep output clean.
935 if (attr.getName() == getFastmathAttrName() &&
936 llvm::cast<arith::FastMathFlagsAttr>(attr.getValue()).getValue() ==
937 arith::FastMathFlags::none)
938 continue;
939 attrs.push_back(attr);
940 }
941 }
942
943 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
944 p << " " << dictAttr << " " << getLhs() << ", ";
945 p << getRhs() << ", " << getAcc();
946
947 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
948 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
949 << getResultType();
950}
951
952static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
953 const std::vector<std::pair<int64_t, int64_t>> &map) {
954 for (auto &dimPair : map) {
955 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
956 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
957 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
958 return false;
959 }
960 return true;
961}
962
963static LogicalResult verifyOutputShape(
964 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
965 Type resType,
966 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
967 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
968 DenseSet<int64_t> lhsContractingDimSet;
969 DenseSet<int64_t> rhsContractingDimSet;
970 for (auto &dimPair : contractingDimMap) {
971 lhsContractingDimSet.insert(dimPair.first);
972 rhsContractingDimSet.insert(dimPair.second);
973 }
974 DenseSet<int64_t> rhsBatchDimSet(llvm::from_range,
975 llvm::make_second_range(batchDimMap));
976
977 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
978 SmallVector<int64_t, 4> expectedResultDims;
979 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
980 if (lhsContractingDimSet.count(i) > 0)
981 continue;
982 expectedResultDims.push_back(lhsType.getDimSize(i));
983 }
984
985 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
986 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
987 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
988 continue;
989 expectedResultDims.push_back(rhsType.getDimSize(i));
990 }
991
992 // Verify 'expectedResultDims'.
993 if (expectedResultDims.empty()) {
994 // No batch or free dimension implies a scalar result.
995 if (llvm::isa<VectorType>(resType) || llvm::isa<VectorType>(accType))
996 return op.emitOpError("invalid accumulator/result vector shape");
997 } else {
998 // At least one batch or free dimension implies a vector result.
999 auto resVectorType = llvm::dyn_cast<VectorType>(resType);
1000 auto accVectorType = llvm::dyn_cast<VectorType>(accType);
1001 if (!resVectorType || !accVectorType)
1002 return op.emitOpError("invalid accumulator/result vector shape");
1003
1004 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
1005 // types fully define the result vector type. This assumes the affine maps
1006 // are well-formed, which must have been verified already.
1007 MLIRContext *ctx = op.getContext();
1008 AffineMap lhsMap = op.getIndexingMapsArray()[0];
1009 AffineMap rhsMap = op.getIndexingMapsArray()[1];
1010 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
1011 return op.emitOpError(
1012 "expected all dimensions to be either a LHS or a RHS dimension");
1013 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
1014 for (auto pair :
1015 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
1016 VectorType v = pair.first;
1017 auto map = pair.second;
1018 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
1019 unsigned pos = map.getDimPosition(idx);
1020 if (!extents[pos])
1021 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
1022 }
1023 }
1024 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
1025 return op.emitOpError("expected all dimensions to get an extent as "
1026 "either a LHS or a RHS dimension");
1027
1028 AffineMap resMap = op.getIndexingMapsArray()[2];
1029 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
1030 /*symbolCount=*/0, extents, ctx);
1031 // Compose the resMap with the extentsMap, which is a constant map.
1032 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
1033 assert(llvm::all_of(expectedMap.getResults(),
1034 llvm::IsaPred<AffineConstantExpr>) &&
1035 "expected constant extent along all dimensions.");
1036 // Extract the expected shape and build the type.
1037 auto expectedShape =
1038 llvm::map_to_vector<4>(expectedMap.getResults(), [](AffineExpr e) {
1039 return cast<AffineConstantExpr>(e).getValue();
1040 });
1041 auto expected =
1042 VectorType::get(expectedShape, resVectorType.getElementType(),
1043 resVectorType.getScalableDims());
1044 if (resVectorType != expected || accVectorType != expected)
1045 return op.emitOpError(
1046 "invalid accumulator/result vector shape, expected: ")
1047 << expected;
1048 }
1049 return success();
1050}
1051
1052LogicalResult ContractionOp::verify() {
1053 VectorType lhsType = getLhsType();
1054 VectorType rhsType = getRhsType();
1055 Type accType = getAccType();
1056 Type resType = getResultType();
1057
1058 if (llvm::isa<IntegerType>(lhsType.getElementType())) {
1059 if (!lhsType.getElementType().isSignlessInteger())
1060 return emitOpError("only supports signless integer types");
1061 }
1062
1063 // Verify that an indexing map was specified for each vector operand.
1064 if (getIndexingMapsArray().size() != 3)
1065 return emitOpError("expected an indexing map for each vector operand");
1066
1067 // Verify that each index map has 'numIterators' inputs, no symbols, and
1068 // that the number of map outputs equals the rank of its associated
1069 // vector operand.
1070 unsigned numIterators = getIteratorTypes().getValue().size();
1071 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1072 auto index = it.index();
1073 auto map = it.value();
1074 if (map.getNumSymbols() != 0)
1075 return emitOpError("expected indexing map ")
1076 << index << " to have no symbols";
1077 auto vectorType = llvm::dyn_cast<VectorType>(getOperand(index).getType());
1078 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
1079 // Verify that the map has the right number of inputs, outputs, and indices.
1080 // This also correctly accounts for (..) -> () for rank-0 results.
1081 if (map.getNumDims() != numIterators)
1082 return emitOpError("expected indexing map ")
1083 << index << " to have " << numIterators << " number of inputs";
1084 if (map.getNumResults() != rank)
1085 return emitOpError("expected indexing map ")
1086 << index << " to have " << rank << " number of outputs";
1087 if (!map.isProjectedPermutation())
1088 return emitOpError("expected indexing map ")
1089 << index << " to be a projected permutation of its inputs";
1090 }
1091
1092 auto contractingDimMap = getContractingDimMap();
1093 auto batchDimMap = getBatchDimMap();
1094
1095 // Verify at least one contracting dimension pair was specified.
1096 if (contractingDimMap.empty())
1097 return emitOpError("expected at least one contracting dimension pair");
1098
1099 // Verify contracting dimension map was properly constructed.
1100 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
1101 return emitOpError("invalid contracting dimension map");
1102
1103 // Verify batch dimension map was properly constructed.
1104 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
1105 return emitOpError("invalid batch dimension map");
1106
1107 // Verify 'accType' and 'resType' shape.
1108 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
1109 contractingDimMap, batchDimMap)))
1110 return failure();
1111
1112 if (!getKindAttr()) {
1113 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
1114 "'vector.kind<add>')");
1115 }
1116
1117 // Verify supported combining kind.
1118 auto vectorType = llvm::dyn_cast<VectorType>(resType);
1119 auto elementType = vectorType ? vectorType.getElementType() : resType;
1120 if (!isSupportedCombiningKind(getKind(), elementType))
1121 return emitOpError("unsupported contraction type");
1122
1123 // Delayed calling of IndexingMapOpInterface::verifyImpl.
1124 return cast<IndexingMapOpInterface>(this->getOperation()).verifyImpl();
1125}
1126
1127// MaskableOpInterface methods.
1128
1129/// Returns the mask type expected by this operation. Mostly used for
1130/// verification purposes. It requires the operation to be vectorized."
1131Type ContractionOp::getExpectedMaskType() {
1132 auto indexingMaps = this->getIndexingMapsArray();
1133 AffineMap lhsIdxMap = indexingMaps[0];
1134 AffineMap rhsIdxMap = indexingMaps[1];
1135 VectorType lhsType = this->getLhsType();
1136 VectorType rhsType = this->getRhsType();
1137
1138 unsigned numVecDims = lhsIdxMap.getNumDims();
1139 SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
1140 SmallVector<bool> maskShapeScalableDims(numVecDims, false);
1141
1142 // Using the information in the indexing maps, extract the size of each
1143 // dimension in the vector.contract operation from the two input operands.
1144 for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
1145 maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1146 maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
1147 lhsType.getScalableDims()[dimIdx];
1148 }
1149 for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
1150 maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
1151 maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
1152 rhsType.getScalableDims()[dimIdx];
1153 }
1154
1155 assert(ShapedType::isStaticShape(maskShape) &&
1156 "Mask shape couldn't be computed");
1157
1158 return VectorType::get(maskShape,
1159 IntegerType::get(lhsType.getContext(), /*width=*/1),
1160 maskShapeScalableDims);
1161}
1162
1163SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
1164 return SmallVector<StringRef>{getIndexingMapsAttrName(),
1165 getIteratorTypesAttrName(), getKindAttrName(),
1166 getFastmathAttrName()};
1167}
1168
1170 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
1171 if (targetExpr == map.getResult(i))
1172 return i;
1173 return -1;
1174}
1175
1176static std::vector<std::pair<int64_t, int64_t>>
1177getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
1178 IteratorType targetIteratorType, MLIRContext *context) {
1179 std::vector<std::pair<int64_t, int64_t>> dimMap;
1180 for (const auto &it : llvm::enumerate(iteratorTypes)) {
1181 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1182 if (iteratorType != targetIteratorType)
1183 continue;
1184 // Search lhs/rhs map results for 'targetExpr'.
1185 auto targetExpr = getAffineDimExpr(it.index(), context);
1186 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
1187 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
1188 if (lhsDim >= 0 && rhsDim >= 0)
1189 dimMap.emplace_back(lhsDim, rhsDim);
1190 }
1191 return dimMap;
1192}
1193
1194void ContractionOp::getIterationBounds(
1195 SmallVectorImpl<int64_t> &iterationBounds) {
1196 auto lhsShape = getLhsType().getShape();
1197 auto resVectorType = llvm::dyn_cast<VectorType>(getResultType());
1198 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1199 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
1200 // Search lhs/rhs map results for 'targetExpr'.
1201 auto targetExpr = getAffineDimExpr(it.index(), getContext());
1202 auto iteratorType = llvm::cast<IteratorTypeAttr>(it.value()).getValue();
1203 if (iteratorType == IteratorType::reduction) {
1204 // Get reduction dim size from lhs shape (same size in rhsShape).
1205 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
1206 assert(lhsDimIndex >= 0);
1207 iterationBounds.push_back(lhsShape[lhsDimIndex]);
1208 continue;
1209 }
1210 // Get parallel dimension size from result shape.
1211 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
1212 assert(resDimIndex >= 0);
1213 assert(resVectorType != nullptr);
1214 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
1215 }
1216}
1217
1218void ContractionOp::getIterationIndexMap(
1219 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
1220 unsigned numMaps = getIndexingMapsArray().size();
1221 iterationIndexMap.resize(numMaps);
1222 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
1223 auto index = it.index();
1224 auto map = it.value();
1225 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1226 auto dim = cast<AffineDimExpr>(map.getResult(i));
1227 iterationIndexMap[index][dim.getPosition()] = i;
1228 }
1229 }
1230}
1231
1232std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
1233 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1234 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
1235 getContext());
1236}
1237
1238std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
1239 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
1240 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
1241 getContext());
1242}
1243
1244std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
1246 getIterationBounds(shape);
1247 return shape;
1248}
1249
1250/// Return a fused vector::ContractionOp which represents a patterns such as:
1251///
1252/// ```mlir
1253/// %c0 = vector.constant 0: ...
1254/// %c = vector.contract %a, %b, %c0: ...
1255/// %e = add %c, %d: ...
1256/// ```
1257///
1258/// by:
1259///
1260/// ```mlir
1261/// %e = vector.contract %a, %b, %d: ...
1262/// ```
1263///
1264/// Return null if the canonicalization does not apply.
1265// TODO: This should be a folding of Add into Contract in core but while they
1266// live in different dialects, it is not possible without unnatural
1267// dependencies.
1268template <typename AddOpType>
1269struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
1270 using OpRewritePattern<AddOpType>::OpRewritePattern;
1271
1272 LogicalResult matchAndRewrite(AddOpType addOp,
1273 PatternRewriter &rewriter) const override {
1274 auto canonicalize = [&](Value maybeContraction,
1275 Value otherOperand) -> vector::ContractionOp {
1276 vector::ContractionOp contractionOp =
1277 dyn_cast_or_null<vector::ContractionOp>(
1278 maybeContraction.getDefiningOp());
1279 if (!contractionOp)
1280 return vector::ContractionOp();
1281 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
1282 contractionOp.getAcc().getDefiningOp())) {
1283 if (maybeZero.getValue() ==
1284 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
1285 IRMapping bvm;
1286 bvm.map(contractionOp.getAcc(), otherOperand);
1287 auto newContraction =
1288 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
1289 rewriter.replaceOp(addOp, newContraction.getResult());
1290 return newContraction;
1291 }
1292 }
1293 return vector::ContractionOp();
1294 };
1295
1296 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
1297 vector::ContractionOp contract = canonicalize(a, b);
1298 contract = contract ? contract : canonicalize(b, a);
1299 return contract ? success() : failure();
1300 }
1301};
1302
1303void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1304 MLIRContext *context) {
1307}
1308
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`. Despite the name, 0 is accepted (non-negative indices).
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
                                         int64_t maxIndex) {
  return index == poisonValue || (index >= 0 && index < maxIndex);
}
1315
1316//===----------------------------------------------------------------------===//
1317// ExtractOp
1318//===----------------------------------------------------------------------===//
1319
1320void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1321 SetIntRangeFn setResultRanges) {
1322 setResultRanges(getResult(), argRanges.front());
1323}
1324
1325void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1326 Value source) {
1327 auto vectorTy = cast<VectorType>(source.getType());
1328 build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
1329}
1330
1331void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1332 Value source, int64_t position) {
1333 build(builder, result, source, ArrayRef<int64_t>{position});
1334}
1335
1336void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1337 Value source, OpFoldResult position) {
1338 build(builder, result, source, ArrayRef<OpFoldResult>{position});
1339}
1340
1341void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1342 Value source, ArrayRef<int64_t> position) {
1343 build(builder, result, source, /*dynamic_position=*/ArrayRef<Value>(),
1344 builder.getDenseI64ArrayAttr(position));
1345}
1346
1347void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1348 Value source, ArrayRef<OpFoldResult> position) {
1349 SmallVector<int64_t> staticPos;
1350 SmallVector<Value> dynamicPos;
1351 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
1352 build(builder, result, source, dynamicPos,
1353 builder.getDenseI64ArrayAttr(staticPos));
1354}
1355
1356LogicalResult
1357ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1358 ExtractOp::Adaptor adaptor,
1359 SmallVectorImpl<Type> &inferredReturnTypes) {
1360 auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
1361 if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
1362 vectorType.getRank()) {
1363 inferredReturnTypes.push_back(vectorType.getElementType());
1364 } else {
1365 auto n = std::min<size_t>(adaptor.getStaticPosition().size(),
1366 vectorType.getRank());
1367 inferredReturnTypes.push_back(VectorType::get(
1368 vectorType.getShape().drop_front(n), vectorType.getElementType(),
1369 vectorType.getScalableDims().drop_front(n)));
1370 }
1371 return success();
1372}
1373
1374LogicalResult vector::ExtractOp::verify() {
1375 if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
1376 if (resTy.getRank() == 0)
1377 return emitError(
1378 "expected a scalar instead of a 0-d vector as the result type");
1379
1380 // Note: This check must come before getMixedPosition() to prevent a crash.
1381 auto dynamicMarkersCount =
1382 llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
1383 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicPosition().size())
1384 return emitOpError(
1385 "mismatch between dynamic and static positions (kDynamic marker but no "
1386 "corresponding dynamic position) -- this can only happen due to an "
1387 "incorrect fold/rewrite");
1388 auto position = getMixedPosition();
1389 if (position.size() > static_cast<unsigned>(getSourceVectorType().getRank()))
1390 return emitOpError(
1391 "expected position attribute of rank no greater than vector rank");
1392 for (auto [idx, pos] : llvm::enumerate(position)) {
1393 if (auto attr = dyn_cast<Attribute>(pos)) {
1394 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
1396 constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) {
1397 return emitOpError("expected position attribute #")
1398 << (idx + 1)
1399 << " to be a non-negative integer smaller than the "
1400 "corresponding vector dimension or poison (-1)";
1401 }
1402 }
1403 }
1404 return success();
1405}
1406
1407template <typename IntType>
1409 return llvm::map_to_vector<4>(
1410 arrayAttr.getAsRange<IntegerAttr>(),
1411 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); });
1412}
1413
1414/// Fold the result of chains of ExtractOp in place by simply concatenating the
1415/// positions.
1416static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1417 if (!extractOp.getSource().getDefiningOp<ExtractOp>())
1418 return failure();
1419
1420 // TODO: Canonicalization for dynamic position not implemented yet.
1421 if (extractOp.hasDynamicPosition())
1422 return failure();
1423
1424 SmallVector<int64_t> globalPosition;
1425 ExtractOp currentOp = extractOp;
1426 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1427 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1428 while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
1429 currentOp = nextOp;
1430 // TODO: Canonicalization for dynamic position not implemented yet.
1431 if (currentOp.hasDynamicPosition())
1432 return failure();
1433 ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
1434 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1435 }
1436 extractOp.setOperand(0, currentOp.getSource());
1437 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1438 OpBuilder b(extractOp.getContext());
1439 std::reverse(globalPosition.begin(), globalPosition.end());
1440 extractOp.setStaticPosition(globalPosition);
1441 return success();
1442}
1443
1444namespace {
1445/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1446/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1447/// Compose TransposeOp permutations as we walk back.
1448/// This helper class keeps an updated extraction position `extractPosition`
1449/// with extra trailing sentinels.
1450/// The sentinels encode the internal transposition status of the result vector.
1451/// As we iterate, extractPosition is permuted and updated.
1452class ExtractFromInsertTransposeChainState {
1453public:
1454 ExtractFromInsertTransposeChainState(ExtractOp e);
1455
1456 /// Iterate over producing insert and transpose ops until we find a fold.
1457 Value fold();
1458
1459private:
1460 /// Return true if the vector at position `a` is contained within the vector
1461 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1462 /// is a prefix of `b`.
1463 template <typename ContainerA, typename ContainerB>
1464 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1465 return a.size() <= b.size() &&
1466 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1467 }
1468
1469 /// Return true if the vector at position `a` intersects the vector at
1470 /// position `b`. Under insert/extract semantics, this is the same as equality
1471 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1472 /// Comparison is on the common prefix (i.e. zip).
1473 template <typename ContainerA, typename ContainerB>
1474 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1475 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1476 if (elemA < 0 || elemB < 0)
1477 continue;
1478 if (elemA != elemB)
1479 return false;
1480 }
1481 return true;
1482 }
1483
1484 /// Folding is only possible in the absence of an internal permutation in the
1485 /// result vector.
1486 bool canFold() {
1487 return (sentinels == ArrayRef(extractPosition).drop_front(extractedRank));
1488 }
1489
1490 // Helper to get the next defining op of interest.
1491 void updateStateForNextIteration(Value v) {
1492 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1493 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1494 };
1495
1496 // Case 1. If we hit a transpose, just compose the map and iterate.
1497 // Invariant: insert + transpose do not change rank, we can always compose.
1498 LogicalResult handleTransposeOp();
1499
1500 // Case 2: the insert position matches extractPosition exactly, early return.
1501 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1502
1503 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1504 /// portion of the source of the insert.
1505 /// Example:
1506 /// ```
1507 /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
1508 /// // extractPosition == [1, 2, 3]
1509 /// %ext = vector.extract %ins[1, 0]: vector<5> from vector<3x4x5>
1510 /// // can fold to vector.extract %source[0, 3]
1511 /// %ext = vector.extract %source[3]: vector<6> from vector<5x6>
1512 /// ```
1513 /// To traverse through %source, we need to set the leading dims to 0 and
1514 /// drop the extra leading dims.
1515 /// This method updates the internal state.
1516 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1517
1518 /// Try to fold in place to extract(source, extractPosition) and return the
1519 /// folded result. Return null if folding is not possible (e.g. due to an
1520 /// internal transposition in the result).
1521 Value tryToFoldExtractOpInPlace(Value source);
1522
1523 ExtractOp extractOp;
1524 int64_t vectorRank;
1525 int64_t extractedRank;
1526
1527 InsertOp nextInsertOp;
1528 TransposeOp nextTransposeOp;
1529
1530 /// Sentinel values that encode the internal permutation status of the result.
1531 /// They are set to (-1, ... , -k) at the beginning and appended to
1532 /// `extractPosition`.
1533 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1534 /// ensure that there is no internal transposition.
1535 /// Internal transposition cannot be accounted for with a folding pattern.
1536 // TODO: We could relax the internal transposition with an extra transposition
1537 // operation in a future canonicalizer.
1538 SmallVector<int64_t> sentinels;
1539 SmallVector<int64_t> extractPosition;
1540};
1541} // namespace
1542
1543ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1544 ExtractOp e)
1545 : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
1546 extractedRank(extractOp.getNumIndices()) {
1547 assert(vectorRank >= extractedRank && "Extracted position overflow");
1548 sentinels.reserve(vectorRank - extractedRank);
1549 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1550 sentinels.push_back(-(i + 1));
1551 extractPosition.assign(extractOp.getStaticPosition().begin(),
1552 extractOp.getStaticPosition().end());
1553 llvm::append_range(extractPosition, sentinels);
1554}
1555
1556// Case 1. If we hit a transpose, just compose the map and iterate.
1557// Invariant: insert + transpose do not change rank, we can always compose.
1558LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1559 // TODO: Canonicalization for dynamic position not implemented yet.
1560 if (extractOp.hasDynamicPosition())
1561 return failure();
1562
1563 if (!nextTransposeOp)
1564 return failure();
1566 nextTransposeOp.getPermutation(), extractOp.getContext()));
1568 return success();
1569}
1570
1571// Case 2: the insert position matches extractPosition exactly, early return.
1572LogicalResult
1573ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1574 Value &res) {
1575 // TODO: Canonicalization for dynamic position not implemented yet.
1576 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1577 return failure();
1578
1579 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1580 if (insertedPos != llvm::ArrayRef(extractPosition).take_front(extractedRank))
1581 return failure();
1582 // Case 2.a. early-exit fold.
1583 res = nextInsertOp.getValueToStore();
1584 // Case 2.b. if internal transposition is present, canFold will be false.
1585 return success(canFold());
1586}
1587
1588/// Case 3: if inserted position is a prefix of extractPosition,
1589/// extract a portion of the source of the insertion.
1590/// This method updates the internal state.
1591LogicalResult
1592ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1593 // TODO: Canonicalization for dynamic position not implemented yet.
1594 if (extractOp.hasDynamicPosition() || nextInsertOp.hasDynamicPosition())
1595 return failure();
1596
1597 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1598 if (!isContainedWithin(insertedPos, extractPosition))
1599 return failure();
1600 // Set leading dims to zero.
1601 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1602 // Drop extra leading dims.
1603 extractPosition.erase(extractPosition.begin(),
1604 extractPosition.begin() + insertedPos.size());
1605 extractedRank = extractPosition.size() - sentinels.size();
1606 // Case 3.a. early-exit fold (break and delegate to post-while path).
1607 res = nextInsertOp.getValueToStore();
1608 // Case 3.b. if internal transposition is present, canFold will be false.
1609 return success();
1610}
1611
1612/// Try to fold in place to extract(source, extractPosition) and return the
1613/// folded result. Return null if folding is not possible (e.g. due to an
1614/// internal transposition in the result).
1615Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1616 Value source) {
1617 // TODO: Canonicalization for dynamic position not implemented yet.
1618 if (extractOp.hasDynamicPosition())
1619 return Value();
1620
1621 // If we can't fold (either internal transposition, or nothing to fold), bail.
1622 bool nothingToFold = (source == extractOp.getSource());
1623 if (nothingToFold || !canFold())
1624 return Value();
1625
1626 // Otherwise, fold by updating the op inplace and return its result.
1627 OpBuilder b(extractOp.getContext());
1628 extractOp.setStaticPosition(
1629 ArrayRef(extractPosition).take_front(extractedRank));
1630 extractOp.getSourceMutable().assign(source);
1631 return extractOp.getResult();
1632}
1633
1634/// Iterate over producing insert and transpose ops until we find a fold.
1635Value ExtractFromInsertTransposeChainState::fold() {
1636 // TODO: Canonicalization for dynamic position not implemented yet.
1637 if (extractOp.hasDynamicPosition())
1638 return Value();
1639
1640 Value valueToExtractFrom = extractOp.getSource();
1641 updateStateForNextIteration(valueToExtractFrom);
1642 while (nextInsertOp || nextTransposeOp) {
1643 // Case 1. If we hit a transpose, just compose the map and iterate.
1644 // Invariant: insert + transpose do not change rank, we can always compose.
1645 if (succeeded(handleTransposeOp())) {
1646 valueToExtractFrom = nextTransposeOp.getVector();
1647 updateStateForNextIteration(valueToExtractFrom);
1648 continue;
1649 }
1650
1651 Value result;
1652 // Case 2: the position match exactly.
1653 if (succeeded(handleInsertOpWithMatchingPos(result)))
1654 return result;
1655
1656 // Case 3: if the inserted position is a prefix of extractPosition, we can
1657 // just extract a portion of the source of the insert.
1658 if (succeeded(handleInsertOpWithPrefixPos(result)))
1659 return tryToFoldExtractOpInPlace(result);
1660
1661 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1662 // values. This is a more difficult case and we bail.
1663 ArrayRef<int64_t> insertedPos = nextInsertOp.getStaticPosition();
1664 if (isContainedWithin(extractPosition, insertedPos) ||
1665 intersectsWhereNonNegative(extractPosition, insertedPos))
1666 return Value();
1667
1668 // Case 5: No intersection, we forward the extract to insertOp.dest().
1669 valueToExtractFrom = nextInsertOp.getDest();
1670 updateStateForNextIteration(valueToExtractFrom);
1671 }
1672 // If after all this we can fold, go for it.
1673 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1674}
1675
1676/// Returns true if the operation has a 0-D vector type operand or result.
1678 auto hasZeroDimVectorType = [](Type type) -> bool {
1679 auto vecType = dyn_cast<VectorType>(type);
1680 return vecType && vecType.getRank() == 0;
1681 };
1682
1683 return llvm::any_of(op->getOperandTypes(), hasZeroDimVectorType) ||
1684 llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
1685}
1686
1687/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
1688/// considered to be 'broadcastlike'.
1689static bool isBroadcastLike(Operation *op) {
1690 if (isa<BroadcastOp>(op))
1691 return true;
1692
1693 auto shapeCast = dyn_cast<ShapeCastOp>(op);
1694 if (!shapeCast)
1695 return false;
1696
1697 // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1698 // Checking that the destination shape has a prefix of 1s is not sufficient,
1699 // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1700 // is that the source shape is a suffix of the destination shape.
1701 VectorType srcType = shapeCast.getSourceVectorType();
1702 ArrayRef<int64_t> srcShape = srcType.getShape();
1703 uint64_t srcRank = srcType.getRank();
1704 ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1705 return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1706}
1707
1708/// Fold extract(broadcast(X)) to either extract(X) or just X.
1709///
1710/// Example:
1711///
1712/// broadcast extract [1][2]
1713/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1714///
1715/// becomes
1716/// extract [1]
1717/// (3,4) -------------------------------------> (4)
1718///
1719///
1720/// The variable names used in this implementation correspond to the above
1721/// shapes as,
1722///
1723/// - (3, 4) is `input` shape.
1724/// - (2, 3, 4) is `broadcast` shape.
1725/// - (4) is `extract` shape.
1726///
1727/// This folding is possible when the suffix of `input` shape is the same as
1728/// `extract` shape.
1729static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1730
1731 Operation *defOp = extractOp.getSource().getDefiningOp();
1732 if (!defOp || !isBroadcastLike(defOp))
1733 return Value();
1734
1735 Value input = defOp->getOperand(0);
1736
1737 // Replace extract(broadcast(X)) with X
1738 if (extractOp.getType() == input.getType())
1739 return input;
1740
1741 // Get required types and ranks in the chain
1742 // input -> broadcast -> extract
1743 // (scalars are treated as rank-0).
1744 auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1745 auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1746 unsigned inputRank = inputType ? inputType.getRank() : 0;
1747 unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1748 unsigned extractRank = extractType ? extractType.getRank() : 0;
1749
1750 // Cannot do without the broadcast if overall the rank increases.
1751 if (extractRank > inputRank)
1752 return Value();
1753
1754 // The above condition guarantees that input is a vector.
1755 assert(inputType && "input must be a vector type because of previous checks");
1756 ArrayRef<int64_t> inputShape = inputType.getShape();
1757
1758 // In the case where there is a broadcast dimension in the suffix, it is not
1759 // possible to replace extract(broadcast(X)) with extract(X). Example:
1760 //
1761 // broadcast extract
1762 // (1) --------> (3,4) ------> (4)
1763 if (extractType &&
1764 extractType.getShape() != inputShape.take_back(extractRank))
1765 return Value();
1766
1767 // Replace extract(broadcast(X)) with extract(X).
1768 // First, determine the new extraction position.
1769 unsigned deltaOverall = inputRank - extractRank;
1770 unsigned deltaBroadcast = broadcastRank - inputRank;
1771 SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1772 SmallVector<OpFoldResult> newPositions(deltaOverall);
1773 IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1774 for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1775 newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1776 }
1777 auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
1778 extractOp->setOperands(
1779 llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
1780 extractOp.setStaticPosition(staticPos);
1781 return extractOp.getResult();
1782}
1783
1784/// Fold extractOp coming from ShuffleOp.
1785///
1786/// Example:
1787///
1788/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1789/// : vector<8xf32>, vector<8xf32>
1790/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1791/// ->
1792/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1793///
1794static Value foldExtractFromShuffle(ExtractOp extractOp) {
1795 // Dynamic positions are not folded as the resulting code would be more
1796 // complex than the input code.
1797 if (extractOp.hasDynamicPosition())
1798 return Value();
1799
1800 auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
1801 if (!shuffleOp)
1802 return Value();
1803
1804 // TODO: 0-D or multi-dimensional vectors not supported yet.
1805 if (shuffleOp.getResultVectorType().getRank() != 1)
1806 return Value();
1807
1808 int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1809 auto shuffleMask = shuffleOp.getMask();
1810 int64_t extractIdx = extractOp.getStaticPosition()[0];
1811 int64_t shuffleIdx = shuffleMask[extractIdx];
1812
1813 // Find the shuffled vector to extract from based on the shuffle index.
1814 if (shuffleIdx < inputVecSize) {
1815 extractOp.setOperand(0, shuffleOp.getV1());
1816 extractOp.setStaticPosition({shuffleIdx});
1817 } else {
1818 extractOp.setOperand(0, shuffleOp.getV2());
1819 extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1820 }
1821
1822 return extractOp.getResult();
1823}
1824
1825// Fold extractOp with source coming from ShapeCast op.
1826static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1827 // TODO: Canonicalization for dynamic position not implemented yet.
1828 if (extractOp.hasDynamicPosition())
1829 return Value();
1830
1831 auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
1832 if (!shapeCastOp)
1833 return Value();
1834
1835 // Get the nth dimension size starting from lowest dimension.
1836 auto getDimReverse = [](VectorType type, int64_t n) {
1837 return type.getShape().take_back(n + 1).front();
1838 };
1839 int64_t destinationRank =
1840 llvm::isa<VectorType>(extractOp.getType())
1841 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1842 : 0;
1843 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1844 return Value();
1845 if (destinationRank > 0) {
1846 auto destinationType =
1847 llvm::cast<VectorType>(extractOp.getResult().getType());
1848 for (int64_t i = 0; i < destinationRank; i++) {
1849 // The lowest dimension of the destination must match the lowest
1850 // dimension of the shapecast op source.
1851 // TODO: This case could be support in a canonicalization pattern.
1852 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1853 getDimReverse(destinationType, i))
1854 return Value();
1855 }
1856 }
1857 // Extract the strides associated with the extract op vector source. Then use
1858 // this to calculate a linearized position for the extract.
1859 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1860 std::reverse(extractedPos.begin(), extractedPos.end());
1862 int64_t stride = 1;
1863 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1864 strides.push_back(stride);
1865 stride *=
1866 getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
1867 }
1868
1869 int64_t position = linearize(extractedPos, strides);
1870 // Then extract the strides associated to the shapeCast op vector source and
1871 // delinearize the position using those strides.
1872 SmallVector<int64_t, 4> newStrides;
1873 int64_t numDimension =
1874 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1875 stride = 1;
1876 for (int64_t i = 0; i < numDimension; i++) {
1877 newStrides.push_back(stride);
1878 stride *=
1879 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1880 }
1881 std::reverse(newStrides.begin(), newStrides.end());
1882 SmallVector<int64_t, 4> newPosition = delinearize(position, newStrides);
1883 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1884 OpBuilder b(extractOp.getContext());
1885 extractOp.setStaticPosition(newPosition);
1886 extractOp.setOperand(0, shapeCastOp.getSource());
1887 return extractOp.getResult();
1888}
1889
1890/// Fold an ExtractOp from ExtractStridedSliceOp.
1891static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1892 // TODO: Canonicalization for dynamic position not implemented yet.
1893 if (extractOp.hasDynamicPosition())
1894 return Value();
1895
1896 auto extractStridedSliceOp =
1897 extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
1898 if (!extractStridedSliceOp)
1899 return Value();
1900
1901 // 0-D vectors not supported.
1902 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1903 if (hasZeroDimVectors(extractStridedSliceOp))
1904 return Value();
1905
1906 // Return if 'extractStridedSliceOp' has non-unit strides.
1907 if (extractStridedSliceOp.hasNonUnitStrides())
1908 return Value();
1909
1910 // Trim offsets for dimensions fully extracted.
1911 auto sliceOffsets =
1912 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1913 while (!sliceOffsets.empty()) {
1914 size_t lastOffset = sliceOffsets.size() - 1;
1915 if (sliceOffsets.back() != 0 ||
1916 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1917 extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
1918 break;
1919 sliceOffsets.pop_back();
1920 }
1921 unsigned destinationRank = 0;
1922 if (auto vecType = llvm::dyn_cast<VectorType>(extractOp.getType()))
1923 destinationRank = vecType.getRank();
1924 // The dimensions of the result need to be untouched by the
1925 // extractStridedSlice op.
1926 if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
1927 sliceOffsets.size())
1928 return Value();
1929
1930 SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
1931 assert(extractedPos.size() >= sliceOffsets.size());
1932 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1933 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1934 extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
1935
1936 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1937 OpBuilder b(extractOp.getContext());
1938 extractOp.setStaticPosition(extractedPos);
1939 return extractOp.getResult();
1940}
1941
1942/// Fold extract_op fed from a chain of insertStridedSlice ops.
1943static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
1944 // TODO: Canonicalization for dynamic position not implemented yet.
1945 if (extractOp.hasDynamicPosition())
1946 return Value();
1947
1948 int64_t destinationRank =
1949 llvm::isa<VectorType>(extractOp.getType())
1950 ? llvm::cast<VectorType>(extractOp.getType()).getRank()
1951 : 0;
1952 auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
1953 if (!insertOp)
1954 return Value();
1955
1956 // 0-D vectors not supported.
1957 assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
1958 if (hasZeroDimVectors(insertOp))
1959 return Value();
1960
1961 while (insertOp) {
1962 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1963 insertOp.getSourceVectorType().getRank();
1964 if (destinationRank > insertOp.getSourceVectorType().getRank())
1965 return Value();
1966 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1967 ArrayRef<int64_t> extractOffsets = extractOp.getStaticPosition();
1968
1969 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1970 return llvm::cast<IntegerAttr>(attr).getInt() != 1;
1971 }))
1972 return Value();
1973 bool disjoint = false;
1974 SmallVector<int64_t, 4> offsetDiffs;
1975 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1976 int64_t start = insertOffsets[dim];
1977 int64_t size =
1978 (dim < insertRankDiff)
1979 ? 1
1980 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1981 int64_t end = start + size;
1982 int64_t offset = extractOffsets[dim];
1983 // Check if the start of the extract offset is in the interval inserted.
1984 if (start <= offset && offset < end) {
1985 if (dim >= insertRankDiff)
1986 offsetDiffs.push_back(offset - start);
1987 continue;
1988 }
1989 disjoint = true;
1990 break;
1991 }
1992 // The extract element chunk overlap with the vector inserted.
1993 if (!disjoint) {
1994 // If any of the inner dimensions are only partially inserted we have a
1995 // partial overlap.
1996 int64_t srcRankDiff =
1997 insertOp.getSourceVectorType().getRank() - destinationRank;
1998 for (int64_t i = 0; i < destinationRank; i++) {
1999 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
2000 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
2001 insertRankDiff))
2002 return Value();
2003 }
2004 extractOp.getSourceMutable().assign(insertOp.getValueToStore());
2005 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2006 OpBuilder b(extractOp.getContext());
2007 extractOp.setStaticPosition(offsetDiffs);
2008 return extractOp.getResult();
2009 }
2010 // If the chunk extracted is disjoint from the chunk inserted, keep
2011 // looking in the insert chain.
2012 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2013 }
2014 return Value();
2015}
2016
2017/// Try to fold the extraction of a scalar from a vector defined by
2018/// vector.from_elements. E.g.:
2019///
2020/// %0 = vector.from_elements %a, %b : vector<2xf32>
2021/// %1 = vector.extract %0[0] : f32 from vector<2xf32>
2022/// ==> fold to %a
2023static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
2024 // Dynamic extractions cannot be folded.
2025 if (extractOp.hasDynamicPosition())
2026 return {};
2027
2028 // Look for extract(from_elements).
2029 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2030 if (!fromElementsOp)
2031 return {};
2032
2033 // Scalable vectors are not supported.
2034 auto vecType = llvm::cast<VectorType>(fromElementsOp.getType());
2035 if (vecType.isScalable())
2036 return {};
2037
2038 // Only extractions of scalars are supported.
2039 int64_t rank = vecType.getRank();
2040 ArrayRef<int64_t> indices = extractOp.getStaticPosition();
2041 if (extractOp.getType() != vecType.getElementType())
2042 return {};
2043 assert(static_cast<int64_t>(indices.size()) == rank &&
2044 "unexpected number of indices");
2045
2046 // Compute flattened/linearized index and fold to operand.
2047 int flatIndex = 0;
2048 int stride = 1;
2049 for (int i = rank - 1; i >= 0; --i) {
2050 flatIndex += indices[i] * stride;
2051 stride *= vecType.getDimSize(i);
2052 }
2053 return fromElementsOp.getElements()[flatIndex];
2054}
2055
2056/// If the dynamic indices of `extractOp` or `insertOp` are in fact constants,
2057/// then fold it.
2058template <typename OpType, typename AdaptorType>
2059static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
2060 SmallVectorImpl<Value> &operands) {
2061 std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
2062 OperandRange dynamicPosition = op.getDynamicPosition();
2063 ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2065 if constexpr (std::is_same_v<OpType, ExtractOp>)
2066 vectorShape = op.getSourceVectorType().getShape();
2067 else
2068 vectorShape = op.getDestVectorType().getShape();
2069
2070 // If the dynamic operands is empty, it is returned directly.
2071 if (!dynamicPosition.size())
2072 return {};
2073
2074 // `index` is used to iterate over the `dynamicPosition`.
2075 unsigned index = 0;
2076
2077 // `opChange` is a flag. If it is true, it means to update `op` in place.
2078 bool opChange = false;
2079 for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
2080 if (ShapedType::isStatic(staticPosition[i]))
2081 continue;
2082 Attribute positionAttr = dynamicPositionAttr[index];
2083 Value position = dynamicPosition[index++];
2084 if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2085 int64_t value = attr.getInt();
2086 // Do not fold if the value is out of bounds (-1 signifies a poison
2087 // value rather than OOB index).
2088 if (value >= -1 && value < vectorShape[i]) {
2089 staticPosition[i] = attr.getInt();
2090 opChange = true;
2091 continue;
2092 }
2093 }
2094 operands.push_back(position);
2095 }
2096
2097 if (opChange) {
2098 op.setStaticPosition(staticPosition);
2099 op.getOperation()->setOperands(operands);
2100 // Return the original result to indicate an in-place folding happened.
2101 return op.getResult();
2102 }
2103 return {};
2104}
2105
2106/// Fold an insert or extract operation into an poison value when a poison index
2107/// is found at any dimension of the static position.
2109 ArrayRef<int64_t> staticPos,
2110 int64_t poisonVal) {
2111 if (!is_contained(staticPos, poisonVal))
2112 return {};
2113
2114 return ub::PoisonAttr::get(context);
2115}
2116
2117/// Fold a vector extract from is a poison source.
2119 if (matchPattern(srcAttr, ub::m_Poison()))
2120 return srcAttr;
2121
2122 return {};
2123}
2124
2125/// Fold a vector extract extracting from a DenseElementsAttr.
2127 Attribute srcAttr) {
2128 auto denseAttr = dyn_cast_if_present<DenseElementsAttr>(srcAttr);
2129 if (!denseAttr) {
2130 return {};
2131 }
2132
2133 if (denseAttr.isSplat()) {
2134 Attribute newAttr = denseAttr.getSplatValue<Attribute>();
2135 if (auto vecDstType = dyn_cast<VectorType>(extractOp.getType()))
2136 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
2137 return newAttr;
2138 }
2139
2140 auto vecTy = cast<VectorType>(extractOp.getSourceVectorType());
2141 if (vecTy.isScalable())
2142 return {};
2143
2144 if (extractOp.hasDynamicPosition()) {
2145 return {};
2146 }
2147
2148 // Materializing subsets of a large constant array can generally lead to
2149 // explosion in IR size because of different combination of subsets that
2150 // can exist. However, vector.extract is a restricted form of subset
2151 // extract where you can only extract non-overlapping (or the same) subset for
2152 // a given rank of the subset. Because of this property, the IR size can only
2153 // increase at most by `rank * size(array)` from a single constant array being
2154 // extracted by multiple extracts.
2155
2156 // Calculate the linearized position of the continuous chunk of elements to
2157 // extract.
2158 SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
2159 copy(extractOp.getStaticPosition(), completePositions.begin());
2160 int64_t startPos =
2161 linearize(completePositions, computeStrides(vecTy.getShape()));
2162 auto denseValuesBegin = denseAttr.value_begin<TypedAttr>() + startPos;
2163
2164 TypedAttr newAttr;
2165 if (auto resVecTy = dyn_cast<VectorType>(extractOp.getType())) {
2166 SmallVector<Attribute> elementValues(
2167 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
2168 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
2169 } else {
2170 newAttr = *denseValuesBegin;
2171 }
2172
2173 return newAttr;
2174}
2175
2176OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
2177 // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
2178 // Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
2179 // mismatch).
2180 if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
2181 return getSource();
2182 if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
2183 return res;
2184 // Fold `arith.constant` indices into the `vector.extract` operation.
2185 // Do not stop here as this fold may enable subsequent folds that require
2186 // constant indices.
2187 SmallVector<Value> operands = {getSource()};
2188 auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
2189
2190 if (auto res = foldPoisonIndexInsertExtractOp(
2191 getContext(), adaptor.getStaticPosition(), kPoisonIndex))
2192 return res;
2193 if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
2194 return res;
2195 if (succeeded(foldExtractOpFromExtractChain(*this)))
2196 return getResult();
2197 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
2198 return res;
2199 if (auto res = foldExtractFromBroadcast(*this))
2200 return res;
2201 if (auto res = foldExtractFromShuffle(*this))
2202 return res;
2203 if (auto res = foldExtractFromShapeCast(*this))
2204 return res;
2205 if (auto val = foldExtractFromExtractStrided(*this))
2206 return val;
2207 if (auto val = foldExtractStridedOpFromInsertChain(*this))
2208 return val;
2209 if (auto val = foldScalarExtractFromFromElements(*this))
2210 return val;
2211
2212 return inplaceFolded;
2213}
2214
2215namespace {
2216
2217// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2218class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2219public:
2220 using Base::Base;
2221
2222 LogicalResult matchAndRewrite(ExtractOp extractOp,
2223 PatternRewriter &rewriter) const override {
2224
2225 Operation *defOp = extractOp.getSource().getDefiningOp();
2226 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2227 if (!defOp || !isBroadcastLike(defOp) || !outType)
2228 return failure();
2229
2230 Value source = defOp->getOperand(0);
2231 if (isBroadcastableTo(source.getType(), outType) !=
2232 BroadcastableToResult::Success)
2233 return failure();
2234
2235 rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
2236 return success();
2237 }
2238};
2239
2240// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
2241class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
2242public:
2243 using Base::Base;
2244
2245 LogicalResult matchAndRewrite(ExtractOp extractOp,
2246 PatternRewriter &rewriter) const override {
2247 auto createMaskOp =
2248 extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
2249 if (!createMaskOp)
2250 return failure();
2251
2252 VectorType extractedMaskType =
2253 llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
2254
2255 if (!extractedMaskType)
2256 return failure();
2257
2258 auto maskOperands = createMaskOp.getOperands();
2259 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2260 VectorType maskType = createMaskOp.getVectorType();
2261
2262 bool containsUnknownDims = false;
2263 bool allFalse = getMaskFormat(createMaskOp) == MaskFormat::AllFalse;
2264
2265 for (size_t dimIdx = 0; !allFalse && dimIdx < extractOpPos.size();
2266 dimIdx++) {
2267 int64_t pos = extractOpPos[dimIdx];
2268 Value operand = maskOperands[dimIdx];
2269 auto constantOp = operand.getDefiningOp<arith::ConstantOp>();
2270 if (!constantOp) {
2271 // Bounds of this dim unknown.
2272 containsUnknownDims = true;
2273 continue;
2274 }
2275
2276 int64_t createMaskBound =
2277 llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
2278
2279 if (pos != ShapedType::kDynamic) {
2280 // If any position is outside the range from the `create_mask`, then the
2281 // extracted mask will be all-false.
2282 allFalse |= pos >= createMaskBound;
2283 } else if (createMaskBound < maskType.getDimSize(dimIdx)) {
2284 // This dim is not all-true and since this is a dynamic index we don't
2285 // know if the extraction is within the true or false region.
2286 // Note: Zero dims have already handled via getMaskFormat().
2287 containsUnknownDims = true;
2288 }
2289 }
2290
2291 if (allFalse) {
2292 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2293 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2294 } else if (!containsUnknownDims) {
2295 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
2296 extractOp, extractedMaskType,
2297 maskOperands.drop_front(extractOpPos.size()));
2298 } else {
2299 return failure();
2300 }
2301 return success();
2302 }
2303};
2304
2305// Pattern to rewrite a ExtractOp(ConstantMask) -> ConstantMask.
2306class ExtractOpFromConstantMask final : public OpRewritePattern<ExtractOp> {
2307public:
2308 using Base::Base;
2309
2310 LogicalResult matchAndRewrite(ExtractOp extractOp,
2311 PatternRewriter &rewriter) const override {
2312 auto constantMaskOp =
2313 extractOp.getSource().getDefiningOp<vector::ConstantMaskOp>();
2314 if (!constantMaskOp)
2315 return failure();
2316
2317 Type resultType = extractOp.getResult().getType();
2318 auto extractedMaskType = dyn_cast<VectorType>(resultType);
2319
2320 ArrayRef<int64_t> extractOpPos = extractOp.getStaticPosition();
2321 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
2322
2323 VectorType maskType = constantMaskOp.getVectorType();
2324
2325 // Check if any extracted position is outside the mask bounds.
2326 for (size_t dimIdx = 0; dimIdx < extractOpPos.size(); dimIdx++) {
2327 int64_t pos = extractOpPos[dimIdx];
2328 if (pos == ShapedType::kDynamic) {
2329 // If the dim is all-true, a dynamic index is fine — any position
2330 // is within the masked region.
2331 if (maskDimSizes[dimIdx] == maskType.getDimSize(dimIdx))
2332 continue;
2333 // Otherwise we don't know if the position is inside or outside of
2334 // the masked area, so bail out.
2335 return failure();
2336 }
2337
2338 // If the position is statically outside of the masked area, the result
2339 // will be all-false.
2340 if (pos >= maskDimSizes[dimIdx]) {
2341 if (extractedMaskType) {
2342 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2343 extractOp, DenseElementsAttr::get(extractedMaskType, false));
2344 } else {
2345 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2346 extractOp, rewriter.getIntegerAttr(resultType, false));
2347 }
2348 return success();
2349 }
2350 }
2351
2352 // All positions are within the mask bounds.
2353 if (extractedMaskType) {
2354 // Vector result: the result is a constant_mask with the remaining
2355 // dimensions.
2356 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
2357 extractOp, extractedMaskType,
2358 maskDimSizes.drop_front(extractOpPos.size()));
2359 } else {
2360 // Scalar result: all positions are within the masked region, so the
2361 // result is true.
2362 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
2363 extractOp, rewriter.getIntegerAttr(resultType, true));
2364 }
2365 return success();
2366 }
2367};
2368
2369// Folds extract(shape_cast(..)) into shape_cast when the total element count
2370// does not change.
2371LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2372 PatternRewriter &rewriter) {
2373 auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
2374 if (!castOp)
2375 return failure();
2376
2377 VectorType sourceType = castOp.getSourceVectorType();
2378 auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2379 if (!targetType)
2380 return failure();
2381
2382 if (sourceType.getNumElements() != targetType.getNumElements())
2383 return failure();
2384
2385 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2386 castOp.getSource());
2387 return success();
2388}
2389
2390/// Try to canonicalize the extraction of a subvector from a vector defined by
2391/// vector.from_elements. E.g.:
2392///
2393/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32>
2394/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32>
2395/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32>
2396LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
2397 PatternRewriter &rewriter) {
2398 // Dynamic positions are not supported.
2399 if (extractOp.hasDynamicPosition())
2400 return failure();
2401
2402 // Scalar extracts are handled by the folder.
2403 auto resultType = dyn_cast<VectorType>(extractOp.getType());
2404 if (!resultType)
2405 return failure();
2406
2407 // Look for extracts from a from_elements op.
2408 auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
2409 if (!fromElementsOp)
2410 return failure();
2411 VectorType inputType = fromElementsOp.getType();
2412
2413 // Scalable vectors are not supported.
2414 if (resultType.isScalable() || inputType.isScalable())
2415 return failure();
2416
2417 // Compute the position of first extracted element and flatten/linearize the
2418 // position.
2419 SmallVector<int64_t> firstElementPos =
2420 llvm::to_vector(extractOp.getStaticPosition());
2421 firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0);
2422 int flatIndex = 0;
2423 int stride = 1;
2424 for (int64_t i = inputType.getRank() - 1; i >= 0; --i) {
2425 flatIndex += firstElementPos[i] * stride;
2426 stride *= inputType.getDimSize(i);
2427 }
2428
2429 // Replace the op with a smaller from_elements op.
2430 rewriter.replaceOpWithNewOp<FromElementsOp>(
2431 extractOp, resultType,
2432 fromElementsOp.getElements().slice(flatIndex,
2433 resultType.getNumElements()));
2434 return success();
2435}
2436
2437/// Replace `vector.extract` with `vector.shape_cast`.
2438///
2439/// BEFORE:
2440/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2441/// AFTER:
2442/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2443///
2444/// The canonical form of vector operations that reshape vectors is shape_cast.
2445struct ExtractToShapeCast final : OpRewritePattern<vector::ExtractOp> {
2446 using Base::Base;
2447 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2448 PatternRewriter &rewriter) const override {
2449 VectorType sourceType = extractOp.getSourceVectorType();
2450 VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2451 if (!outType)
2452 return failure();
2453
2454 if (sourceType.getNumElements() != outType.getNumElements())
2455 return rewriter.notifyMatchFailure(
2456 extractOp, "extract to vector with fewer elements");
2457
2458 // Negative values in `position` means that the extacted value is poison.
2459 // There is a vector.extract folder for this.
2460 if (llvm::any_of(extractOp.getMixedPosition(),
2461 [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2462 return rewriter.notifyMatchFailure(extractOp,
2463 "leaving for extract poison folder");
2464
2465 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
2466 extractOp.getSource());
2467
2468 return success();
2469 }
2470};
2471
2472/// Folds vector.extract from vector.insert when the extract position is a
2473/// prefix of the insert position and the remaining (un-indexed) dimensions
2474/// of the extracted sub-vector are all size 1. In that case the extracted
2475/// value is fully determined by the inserted value.
2476///
2477/// Examples:
2478/// %ins = vector.insert %s, %v [3, 0] : f32 into vector<16x1xf32>
2479/// %ext = vector.extract %ins [3] : vector<1xf32> from vector<16x1xf32>
2480/// folds to:
2481/// %ext = vector.broadcast %s : f32 to vector<1xf32>
2482///
2483/// %ins = vector.insert %s, %v [0, 0] : vector<1xf32> into vector<16x1x1xf32>
2484// %ext = vector.extract %ins [0] : vector<1x1xf32> from vector<16x1x1xf32>
2485/// folds to:
2486/// %ext = vector.shape_cast %arg0 : vector<1xf32> to vector<1x1xf32>
2487struct FoldExtractFromInsertUnitDim final
2488 : OpRewritePattern<vector::ExtractOp> {
2489 using Base::Base;
2490
2491 LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
2492 PatternRewriter &rewriter) const override {
2493 if (extractOp.hasDynamicPosition())
2494 return failure();
2495
2496 auto insertOp = extractOp.getSource().getDefiningOp<vector::InsertOp>();
2497 if (!insertOp || insertOp.hasDynamicPosition())
2498 return failure();
2499
2500 ArrayRef<int64_t> extractPos = extractOp.getStaticPosition();
2501 ArrayRef<int64_t> insertPos = insertOp.getStaticPosition();
2502
2503 // The extract position must be a strict prefix of the insert position.
2504 if (extractPos.size() >= insertPos.size() ||
2505 extractPos != insertPos.take_front(extractPos.size()))
2506 return failure();
2507
2508 // The remaining dimensions (those not indexed by the extract) must all
2509 // be size 1 in the source vector type. This guarantees that the inserted
2510 // value fully determines the extracted sub-vector.
2511 auto srcVecType = extractOp.getSourceVectorType();
2512 for (int64_t i = extractPos.size(), e = srcVecType.getRank(); i < e; ++i)
2513 if (srcVecType.getDimSize(i) != 1)
2514 return failure();
2515
2516 Value inserted = insertOp.getValueToStore();
2517 Type extractedType = extractOp.getResult().getType();
2518 if (isa<VectorType>(inserted.getType())) {
2519 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, extractedType,
2520 inserted);
2521 } else {
2522 // The inserted value fully determines the extracted sub-vector; broadcast
2523 // it to the extracted type.
2524 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2525 extractOp, extractOp.getResult().getType(),
2526 insertOp.getValueToStore());
2527 }
2528 return success();
2529 }
2530};
2531
2532} // namespace
2533
2534void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
2535 MLIRContext *context) {
2536 results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask,
2537 ExtractOpFromConstantMask, ExtractToShapeCast,
2538 FoldExtractFromInsertUnitDim>(context);
2539 results.add(foldExtractFromShapeCastToShapeCast);
2540 results.add(foldExtractFromFromElements);
2541}
2542
SmallVectorImpl<int64_t> &results) {
  // Append the integer value of each element of `arrayAttr` to `results`, in
  // order. llvm::cast asserts that every element is an IntegerAttr.
  for (auto attr : arrayAttr)
    results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
}
2548
2549//===----------------------------------------------------------------------===//
2550// FmaOp
2551//===----------------------------------------------------------------------===//
2552
2553std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
2554 return llvm::to_vector<4>(getVectorType().getShape());
2555}
2556
2557//===----------------------------------------------------------------------===//
2558// ToElementsOp
2559//===----------------------------------------------------------------------===//
2560
2561/// Returns true if all the `operands` are defined by `defOp`.
2562/// Otherwise, returns false.
2563static bool haveSameDefiningOp(OperandRange operands, Operation *defOp) {
2564 if (operands.empty())
2565 return false;
2566
2567 return llvm::all_of(operands, [&](Value operand) {
2568 Operation *currentDef = operand.getDefiningOp();
2569 return currentDef == defOp;
2570 });
2571}
2572
2573/// Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2574/// (%e0, %e1, ...). For example:
2575///
2576/// %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2577/// %1:3 = vector.to_elements %0 : vector<3xf32>
2578/// user_op %1#0, %1#1, %1#2
2579///
2580/// becomes:
2581///
2582/// user_op %a, %b, %c
2583///
2584static LogicalResult
2585foldToElementsFromElements(ToElementsOp toElementsOp,
2587 auto fromElementsOp =
2588 toElementsOp.getSource().getDefiningOp<FromElementsOp>();
2589 if (!fromElementsOp)
2590 return failure();
2591
2592 llvm::append_range(results, fromElementsOp.getElements());
2593 return success();
2594}
2595
2596/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
2597///
2598/// Example:
2599/// %b = vector.broadcast %x : i32 to vector<3xf32>
2600/// %e:3 = vector.to_elements %b : vector<3xf32>
2601/// user_op %e#0, %e#1, %e#2
2602/// becomes:
2603/// user_op %x, %x, %x
2604///
2605/// The vector source case is handled by a canonicalization pattern.
2606static LogicalResult
2607foldToElementsOfBroadcast(ToElementsOp toElementsOp,
2609 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2610 if (!bcastOp)
2611 return failure();
2612 // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
2613 if (isa<VectorType>(bcastOp.getSource().getType()))
2614 return failure();
2615
2616 auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
2617
2618 Value scalar = bcastOp.getSource();
2619 results.assign(resultVecType.getNumElements(), scalar);
2620 return success();
2621}
2622
2623LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
2624 SmallVectorImpl<OpFoldResult> &results) {
2625 if (succeeded(foldToElementsFromElements(*this, results)))
2626 return success();
2627
2628 // Y = ToElements(ShapeCast(X)) -> Y = ToElements(X)
2629 if (auto shapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
2630 setOperand(shapeCast.getSource());
2631 return success();
2632 }
2633
2634 return foldToElementsOfBroadcast(*this, results);
2635}
2636
2637LogicalResult
2638ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
2639 ToElementsOp::Adaptor adaptor,
2640 SmallVectorImpl<Type> &inferredReturnTypes) {
2641 auto vecType = cast<VectorType>(adaptor.getSource().getType());
2642 Type elType = vecType.getElementType();
2643 inferredReturnTypes.append(vecType.getNumElements(), elType);
2644 return success();
2645}
2646
2647/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
2648/// vector.
2649/// - Build `vector.to_elements %v` and remap each destination element to the
2650/// corresponding source element using broadcast rules (match or 1 →
2651/// replicate).
2652///
2653/// Example:
2654/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
2655/// %e:6 = vector.to_elements %v : vector<3x2xf32>
2656/// becomes:
2657/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
2658/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2659/// // %src_elems#1, %src_elems#0, %src_elems#1
2660struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
2661 using Base::Base;
2662
2663 LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
2664 PatternRewriter &rewriter) const override {
2665 auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
2666 if (!bcastOp)
2667 return failure();
2668
2669 // Only handle broadcasts from a vector source here.
2670 auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
2671 if (!srcType)
2672 return failure();
2673
2674 auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
2675
2676 ArrayRef<int64_t> dstShape = dstType.getShape();
2677 ArrayRef<int64_t> srcShape = srcType.getShape();
2678
2679 int64_t dstRank = dstShape.size();
2680 int64_t srcRank = srcShape.size();
2681
2682 // Create elements for the broadcast source vector.
2683 auto srcElems = vector::ToElementsOp::create(
2684 rewriter, toElementsOp.getLoc(), bcastOp.getSource());
2685
2686 int64_t dstCount = llvm::product_of(dstShape);
2687
2688 SmallVector<Value> replacements;
2689 replacements.reserve(dstCount);
2690
2691 // For each element of the destination, determine which element of the
2692 // source should be used. We walk all destination positions using a single
2693 // counter, decode it into per-dimension indices, then build the matching
2694 // source position: use the same index where sizes match, and use 0 where
2695 // the source size is 1 (replication). This mapping is needed so we can
2696 // replace each result of to_elements with the corresponding element from
2697 // the broadcast source.
2698 // Inner-dimension stretch example:
2699 // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
2700 // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
2701 // becomes:
2702 // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
2703 // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
2704 // // %src_elems#1, %src_elems#0, %src_elems#1,
2705 // // %src_elems#2, %src_elems#3, %src_elems#2,
2706 // // %src_elems#3, %src_elems#2, %src_elems#3
2707
2708 // Row-major strides for the destination shape.
2709 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
2710 // Row-major strides for the source shape.
2711 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
2712 SmallVector<int64_t> dstIdx(dstRank);
2713 SmallVector<int64_t> srcIdx(srcRank);
2714 for (int64_t lin = 0; lin < dstCount; ++lin) {
2715 // Convert linear destination index to per-dimension indices.
2716 dstIdx = delinearize(lin, dstStrides);
2717 for (int64_t k = 0; k < srcRank; ++k)
2718 srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
2719 // Convert per-dimension source indices back to a linear index.
2720 int64_t srcLin = linearize(srcIdx, srcStrides);
2721 replacements.push_back(srcElems.getResult(srcLin));
2722 }
2723
2724 rewriter.replaceOp(toElementsOp, replacements);
2725 return success();
2726 }
2727};
2728
2729void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2730 MLIRContext *context) {
2731 results.add<ToElementsOfBroadcast>(context);
2732}
2733
2734//===----------------------------------------------------------------------===//
2735// FromElementsOp
2736//===----------------------------------------------------------------------===//
2737
2738/// Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2739///
2740/// Case #1: Input and output vectors are the same.
2741///
2742/// %0:3 = vector.to_elements %a : vector<3xf32>
2743/// %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2744/// user_op %1
2745///
2746/// becomes:
2747///
2748/// user_op %a
2749///
2750static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
2751 OperandRange fromElemsOperands = fromElementsOp.getElements();
2752 if (fromElemsOperands.empty())
2753 return {};
2754
2755 auto toElementsOp = fromElemsOperands[0].getDefiningOp<ToElementsOp>();
2756 if (!toElementsOp)
2757 return {};
2758
2759 if (!haveSameDefiningOp(fromElemsOperands, toElementsOp))
2760 return {};
2761
2762 // Case #1: Input and output vectors are the same. Forward the input vector.
2763 Value toElementsInput = toElementsOp.getSource();
2764 if (fromElementsOp.getType() == toElementsInput.getType() &&
2765 llvm::equal(fromElemsOperands, toElementsOp.getResults())) {
2766 return toElementsInput;
2767 }
2768
2769 // TODO: Support cases with different input and output shapes and different
2770 // number of elements.
2771
2772 return {};
2773}
2774
2775/// Fold vector.from_elements to a constant when all operands are constants.
2776/// Example:
2777/// %c1 = arith.constant 1 : i32
2778/// %c2 = arith.constant 2 : i32
2779/// %v = vector.from_elements %c1, %c2 : vector<2xi32>
2780/// =>
2781/// %v = arith.constant dense<[1, 2]> : vector<2xi32>
2782///
2783static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
2784 ArrayRef<Attribute> elements) {
2785 // Check for null or poison attributes before any processing.
2786 if (llvm::any_of(elements, [](Attribute attr) {
2787 return !attr || matchPattern(attr, ub::m_Poison());
2788 }))
2789 return {};
2790
2791 // DenseElementsAttr only supports int/index/float/complex types.
2792 auto destVecType = fromElementsOp.getDest().getType();
2793 auto destEltType = destVecType.getElementType();
2794 if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType))
2795 return {};
2796
2797 // Constant attributes might have a different type than the return type.
2798 // Convert them before creating the dense elements attribute.
2799 auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) {
2800 return convertNumericAttr(attr, destEltType);
2801 });
2802
2803 return DenseElementsAttr::get(destVecType, convertedElements);
2804}
2805
2806OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2807 if (auto res = foldFromElementsToElements(*this))
2808 return res;
2809 if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
2810 return res;
2811
2812 return {};
2813}
2814
2815/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2816/// same. Example:
2817/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2818/// =>
2819/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2820static LogicalResult
2821rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2822 PatternRewriter &rewriter) {
2823 if (!llvm::all_equal(fromElementsOp.getElements()))
2824 return failure();
2825 rewriter.replaceOpWithNewOp<BroadcastOp>(
2826 fromElementsOp, fromElementsOp.getType(),
2827 fromElementsOp.getElements().front());
2828 return success();
2829}
2830
2831/// Rewrite from_elements on multiple scalar extracts as a shape_cast
2832/// on a single extract. Example:
2833/// %0 = vector.extract %source[0, 0] : i8 from vector<2x2xi8>
2834/// %1 = vector.extract %source[0, 1] : i8 from vector<2x2xi8>
2835/// %2 = vector.from_elements %0, %1 : vector<2xi8>
2836///
2837/// becomes
2838/// %1 = vector.extract %source[0] : vector<1x2xi8> from vector<2x2xi8>
2839/// %2 = vector.shape_cast %1 : vector<1x2xi8> to vector<2xi8>
2840///
2841/// The requirements for this to be valid are
2842///
2843/// i) The elements are extracted from the same vector (%source).
2844///
2845/// ii) The elements form a suffix of %source. Specifically, the number
2846/// of elements is the same as the product of the last N dimension sizes
2847/// of %source, for some N.
2848///
2849/// iii) The elements are extracted contiguously in ascending order.
2850
2851class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2852
2853 using Base::Base;
2854
2855 LogicalResult matchAndRewrite(FromElementsOp fromElements,
2856 PatternRewriter &rewriter) const override {
2857
2858 // Handled by `rewriteFromElementsAsBroadcast`.
2859 if (fromElements.getType().getNumElements() == 1)
2860 return failure();
2861
2862 // The common source that all elements are extracted from, if one exists.
2864 // The position of the combined extract operation, if one is created.
2865 ArrayRef<int64_t> combinedPosition;
2866 // The expected index of extraction of the current element in the loop, if
2867 // elements are extracted contiguously in ascending order.
2868 SmallVector<int64_t> expectedPosition;
2869
2870 for (auto [insertIndex, element] :
2871 llvm::enumerate(fromElements.getElements())) {
2872
2873 // Check that the element is from a vector.extract operation.
2874 auto extractOp = element.getDefiningOp<vector::ExtractOp>();
2875 if (!extractOp) {
2876 return rewriter.notifyMatchFailure(fromElements,
2877 "element not from vector.extract");
2878 }
2879
2880 // Check condition (i) by checking that all elements have the same source
2881 // as the first element.
2882 if (insertIndex == 0) {
2883 source = extractOp.getSource();
2884 } else if (extractOp.getSource() != source) {
2885 return rewriter.notifyMatchFailure(fromElements,
2886 "element from different vector");
2887 }
2888
2889 ArrayRef<int64_t> position = extractOp.getStaticPosition();
2890 int64_t rank = position.size();
2891 assert(rank == source.getType().getRank() &&
2892 "scalar extract must have full rank position");
2893
2894 // Check condition (ii) by checking that the position that the first
2895 // element is extracted from has sufficient trailing 0s. For example, in
2896 //
2897 // %elm0 = vector.extract %source[1, 0, 0] : i8 from vector<2x3x4xi8>
2898 // [...]
2899 // %elms = vector.from_elements %elm0, [...] : vector<12xi8>
2900 //
2901 // The 2 trailing 0s in the position of extraction of %elm0 cover 3*4 = 12
2902 // elements, which is the number of elements of %n, so this is valid.
2903 if (insertIndex == 0) {
2904 const int64_t numElms = fromElements.getType().getNumElements();
2905 int64_t numSuffixElms = 1;
2906 int64_t index = rank;
2907 while (index > 0 && position[index - 1] == 0 &&
2908 numSuffixElms < numElms) {
2909 numSuffixElms *= source.getType().getDimSize(index - 1);
2910 --index;
2911 }
2912 if (numSuffixElms != numElms) {
2913 return rewriter.notifyMatchFailure(
2914 fromElements, "elements do not form a suffix of source");
2915 }
2916 expectedPosition = llvm::to_vector(position);
2917 combinedPosition = position.drop_back(rank - index);
2918 }
2919
2920 // Check condition (iii).
2921 else if (expectedPosition != position) {
2922 return rewriter.notifyMatchFailure(
2923 fromElements, "elements not in ascending order (static order)");
2924 }
2925 increment(expectedPosition, source.getType().getShape());
2926 }
2927
2928 auto extracted = rewriter.createOrFold<vector::ExtractOp>(
2929 fromElements.getLoc(), source, combinedPosition);
2930
2931 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
2932 fromElements, fromElements.getType(), extracted);
2933
2934 return success();
2935 }
2936
2937 /// Increments n-D `indices` by 1 starting from the innermost dimension.
2938 static void increment(MutableArrayRef<int64_t> indices,
2940 for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
2941 indices[dim] += 1;
2942 if (indices[dim] < shape[dim])
2943 break;
2944 indices[dim] = 0;
2945 }
2946 }
2947};
2948
2949void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
2950 MLIRContext *context) {
2952 results.add<FromElementsToShapeCast>(context);
2953}
2954
2955//===----------------------------------------------------------------------===//
2956// BroadcastOp
2957//===----------------------------------------------------------------------===//
2958
2959void BroadcastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
2960 SetIntRangeFn setResultRanges) {
2961 setResultRanges(getResult(), argRanges.front());
2962}
2963
2964std::optional<SmallVector<int64_t, 4>> BroadcastOp::getShapeForUnroll() {
2965 return llvm::to_vector<4>(getResultVectorType().getShape());
2966}
2967
2968/// Return the dimensions of the result vector that were formerly ones in the
2969/// source tensor and thus correspond to "dim-1" broadcasting.
2970static llvm::SetVector<int64_t>
2972 ArrayRef<int64_t> dstShape) {
2973 int64_t rankDiff = dstShape.size() - srcShape.size();
2974 int64_t dstDim = rankDiff;
2976 for (auto [s1, s2] :
2977 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
2978 if (s1 != s2) {
2979 assert(s1 == 1 && "expected \"dim-1\" broadcasting");
2980 res.insert(dstDim);
2981 }
2982 ++dstDim;
2983 }
2984 return res;
2985}
2986
2987llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
2988 // Scalar broadcast is without any unit dim broadcast.
2989 auto srcVectorType = llvm::dyn_cast<VectorType>(getSourceType());
2990 if (!srcVectorType)
2991 return {};
2992 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
2993 getResultVectorType().getShape());
2994}
2995
2996/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
2997/// `broadcastedDims` dimensions in the dstShape are broadcasted.
2998/// This requires (and asserts) that the broadcast is free of "dim-1"
2999/// broadcasting.
3000/// Since vector.broadcast only allows expanding leading dimensions, an extra
3001/// vector.transpose may be inserted to make the broadcast possible.
3002/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
3003/// the helper will assert. This means:
3004/// 1. `dstShape` must not be empty.
3005/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
3006/// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
3007// must match the `value` shape.
3008Value BroadcastOp::createOrFoldBroadcastOp(
3009 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
3010 const llvm::SetVector<int64_t> &broadcastedDims) {
3011 assert(!dstShape.empty() && "unexpected empty dst shape");
3012
3013 // Well-formedness check.
3014 SmallVector<int64_t> checkShape;
3015 for (int i = 0, e = dstShape.size(); i < e; ++i) {
3016 if (broadcastedDims.contains(i))
3017 continue;
3018 checkShape.push_back(dstShape[i]);
3019 }
3020 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
3021 "ill-formed broadcastedDims contains values not confined to "
3022 "destVectorShape");
3023
3024 Location loc = value.getLoc();
3025 Type elementType = getElementTypeOrSelf(value.getType());
3026 VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
3027 VectorType dstVectorType = VectorType::get(dstShape, elementType);
3028
3029 // Step 2. If scalar -> dstShape broadcast, just do it.
3030 if (!srcVectorType) {
3031 assert(checkShape.empty() &&
3032 "ill-formed createOrFoldBroadcastOp arguments");
3033 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
3034 }
3035
3036 assert(srcVectorType.getShape().equals(checkShape) &&
3037 "ill-formed createOrFoldBroadcastOp arguments");
3038
3039 // Step 3. Since vector.broadcast only allows creating leading dims,
3040 // vector -> dstShape broadcast may require a transpose.
3041 // Traverse the dims in order and construct:
3042 // 1. The leading entries of the broadcastShape that is guaranteed to be
3043 // achievable by a simple broadcast.
3044 // 2. The induced permutation for the subsequent vector.transpose that will
3045 // bring us from `broadcastShape` back to he desired `dstShape`.
3046 // If the induced permutation is not the identity, create a vector.transpose.
3047 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
3048 broadcastShape.reserve(dstShape.size());
3049 // Consider the example:
3050 // srcShape = 2x4
3051 // dstShape = 1x2x3x4x5
3052 // broadcastedDims = [0, 2, 4]
3053 //
3054 // We want to build:
3055 // broadcastShape = 1x3x5x2x4
3056 // permutation = [0, 2, 4, 1, 3]
3057 // ---V--- -----V-----
3058 // leading broadcast part src shape part
3059 //
3060 // Note that the trailing dims of broadcastShape are exactly the srcShape
3061 // by construction.
3062 // nextSrcShapeDim is used to keep track of where in the permutation the
3063 // "src shape part" occurs.
3064 int64_t nextSrcShapeDim = broadcastedDims.size();
3065 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
3066 if (broadcastedDims.contains(i)) {
3067 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
3068 // bring it to the head of the broadcastShape.
3069 // It will need to be permuted back from `broadcastShape.size() - 1` into
3070 // position `i`.
3071 broadcastShape.push_back(dstShape[i]);
3072 permutation[i] = broadcastShape.size() - 1;
3073 } else {
3074 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
3075 // shape and needs to be permuted into position `i`.
3076 // Don't touch `broadcastShape` here, the whole srcShape will be
3077 // appended after.
3078 permutation[i] = nextSrcShapeDim++;
3079 }
3080 }
3081 // 3.c. Append the srcShape.
3082 llvm::append_range(broadcastShape, srcVectorType.getShape());
3083
3084 // Ensure there are no "dim-1" broadcasts.
3085 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
3086 .empty() &&
3087 "unexpected \"dim-1\" broadcast");
3088
3089 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
3090 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
3091 vector::BroadcastableToResult::Success &&
3092 "must be broadcastable");
3093 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
3094 // Step 4. If we find any dimension that indeed needs to be permuted,
3095 // immediately return a new vector.transpose.
3096 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
3097 if (permutation[i] != i)
3098 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
3099 // Otherwise return res.
3100 return res;
3101}
3102
Type srcType, VectorType dstVectorType,
std::pair<VectorDim, VectorDim> *mismatchingDims) {
  // Decides whether a value of `srcType` may be vector.broadcast to
  // `dstVectorType`. When a trailing-dimension mismatch is found and
  // `mismatchingDims` is non-null, the offending source/destination sizes and
  // scalability flags are written to it.
  // NOTE(review): in this copy the leading signature line and the `return`
  // statements of several early-exit checks below appear to be missing;
  // confirm against the upstream definition.

  // Broadcast scalar to vector of the same element type.
  if (isa<VectorElementTypeInterface>(srcType) && dstVectorType &&
      srcType == getElementTypeOrSelf(dstVectorType))
  // From now on, only vectors broadcast.
  VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType);
  if (!srcVectorType)

  int64_t srcRank = srcVectorType.getRank();
  int64_t dstRank = dstVectorType.getRank();
  // The source may not have more dims than the destination.
  if (srcRank > dstRank)
  // Source has an exact match or singleton value for all trailing dimensions
  // (all leading dimensions are simply duplicated).
  // Shapes are right-aligned: source dim `dimIdx` pairs with destination dim
  // `lead + dimIdx`.
  int64_t lead = dstRank - srcRank;
  for (int64_t dimIdx = 0; dimIdx < srcRank; ++dimIdx) {
    // Have mismatching dims (in the sense of vector.broadcast semantics) been
    // encountered?
    bool foundMismatchingDims = false;

    // Check fixed-width dims.
    int64_t srcDim = srcVectorType.getDimSize(dimIdx);
    int64_t dstDim = dstVectorType.getDimSize(lead + dimIdx);
    if (srcDim != 1 && srcDim != dstDim)
      foundMismatchingDims = true;

    // Check scalable flags.
    bool srcDimScalableFlag = srcVectorType.getScalableDims()[dimIdx];
    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + dimIdx];
    // A scalable unit source dim ([1]) may only map to a unit dim.
    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
        // 1 -> [N] is fine, everything else should be rejected when mixing
        // fixed-width and scalable dims
        (srcDimScalableFlag != dstDimScalableFlag &&
         (srcDim != 1 || srcDimScalableFlag)))
      foundMismatchingDims = true;

    if (foundMismatchingDims) {
      // Report the first offending pair to the caller, if requested.
      if (mismatchingDims != nullptr) {
        mismatchingDims->first.dim = srcDim;
        mismatchingDims->first.isScalable = srcDimScalableFlag;

        mismatchingDims->second.dim = dstDim;
        mismatchingDims->second.isScalable = dstDimScalableFlag;
      }
    }
  }

}
3157
LogicalResult BroadcastOp::verify() {
  // Verifies the source/result types and emits a diagnostic for each failure
  // kind: rank too high, dimension mismatch (scalable dims are printed in
  // brackets, e.g. "[4]"), or non-vector source.
  // NOTE(review): the call whose trailing arguments appear below (presumably
  // isBroadcastableTo, whose parameter list matches) and the conditions
  // guarding each `return` appear to be missing from this copy; confirm
  // against upstream.
  std::pair<VectorDim, VectorDim> mismatchingDims;
  getSourceType(), getResultVectorType(), &mismatchingDims);
  return success();
  return emitOpError("source rank higher than destination rank");
  return emitOpError("dimension mismatch (")
         << (mismatchingDims.first.isScalable ? "[" : "")
         << mismatchingDims.first.dim
         << (mismatchingDims.first.isScalable ? "]" : "") << " vs. "
         << (mismatchingDims.second.isScalable ? "[" : "")
         << mismatchingDims.second.dim
         << (mismatchingDims.second.isScalable ? "]" : "") << ")";
  }
  return emitOpError("source type is not a vector");
  llvm_unreachable("unexpected vector.broadcast op error");
}
3179
3180// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
3181// with broadcast's result type and shape_cast only adds or removes ones in the
3182// leading dimensions.
3183static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
3184 auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
3185 if (!srcShapeCast)
3186 return failure();
3187
3188 VectorType srcType = srcShapeCast.getSourceVectorType();
3189 VectorType destType = broadcastOp.getResultVectorType();
3190 // Check type compatibility.
3191 if (vector::isBroadcastableTo(srcType, destType) !=
3193 return failure();
3194
3195 ArrayRef<int64_t> srcShape = srcType.getShape();
3196 ArrayRef<int64_t> shapecastShape =
3197 srcShapeCast.getResultVectorType().getShape();
3198 // Trailing dimensions should be the same if shape_cast only alters the
3199 // leading dimensions.
3200 unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
3201 if (!llvm::equal(srcShape.take_back(numTrailingDims),
3202 shapecastShape.take_back(numTrailingDims)))
3203 return failure();
3204
3205 assert(all_of(srcShape.drop_back(numTrailingDims),
3206 [](int64_t E) { return E == 1; }) &&
3207 all_of(shapecastShape.drop_back(numTrailingDims),
3208 [](int64_t E) { return E == 1; }) &&
3209 "ill-formed shape_cast");
3210
3211 broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
3212 return success();
3213}
3214
3215OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
3216 if (getSourceType() == getResultVectorType())
3217 return getSource();
3218 if (succeeded(foldBroadcastOfShapeCast(*this)))
3219 return getResult();
3220
3221 if (!adaptor.getSource())
3222 return {};
3223 auto vectorType = getResultVectorType();
3224 if (auto attr = llvm::dyn_cast<IntegerAttr>(adaptor.getSource())) {
3225 if (vectorType.getElementType() != attr.getType())
3226 return {};
3227 return DenseElementsAttr::get(vectorType, attr);
3228 }
3229 if (auto attr = llvm::dyn_cast<FloatAttr>(adaptor.getSource())) {
3230 if (vectorType.getElementType() != attr.getType())
3231 return {};
3232 return DenseElementsAttr::get(vectorType, attr);
3233 }
3234 if (auto attr = llvm::dyn_cast<SplatElementsAttr>(adaptor.getSource()))
3235 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
3236 if (matchPattern(adaptor.getSource(), ub::m_Poison()))
3237 return ub::PoisonAttr::get(getContext());
3238 return {};
3239}
3240
3241namespace {
3242
3243// Fold broadcast1(broadcast2(x)) into broadcast1(x).
3244struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
3245 using Base::Base;
3246
3247 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3248 PatternRewriter &rewriter) const override {
3249 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
3250 if (!srcBroadcast)
3251 return failure();
3252 rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
3253 broadcastOp.getResultVectorType(),
3254 srcBroadcast.getSource());
3255 return success();
3256 }
3257};
3258
3259/// Replace `vector.broadcast` with `vector.shape_cast`.
3260///
3261/// BEFORE:
3262/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3263/// AFTER:
3264/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
3265///
3266/// The canonical form of vector operations that reshape vectors is shape_cast.
3267struct BroadcastToShapeCast final
3268 : public OpRewritePattern<vector::BroadcastOp> {
3269 using Base::Base;
3270 LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
3271 PatternRewriter &rewriter) const override {
3272
3273 auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
3274 if (!sourceType) {
3275 return rewriter.notifyMatchFailure(
3276 broadcast, "source is a scalar, shape_cast doesn't support scalar");
3277 }
3278
3279 VectorType outType = broadcast.getType();
3280 if (sourceType.getNumElements() != outType.getNumElements()) {
3281 return rewriter.notifyMatchFailure(
3282 broadcast, "broadcast to a greater number of elements");
3283 }
3284
3285 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
3286 broadcast.getSource());
3287 return success();
3288 }
3289};
3290} // namespace
3291
3292void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
3293 MLIRContext *context) {
3294 results.add<BroadcastFolder, BroadcastToShapeCast>(context);
3295}
3296
3297//===----------------------------------------------------------------------===//
3298// ShuffleOp
3299//===----------------------------------------------------------------------===//
3300
3301LogicalResult ShuffleOp::verify() {
3302 VectorType resultType = getResultVectorType();
3303 VectorType v1Type = getV1VectorType();
3304 VectorType v2Type = getV2VectorType();
3305 // Verify ranks.
3306 int64_t resRank = resultType.getRank();
3307 int64_t v1Rank = v1Type.getRank();
3308 int64_t v2Rank = v2Type.getRank();
3309 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
3310 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
3311 if (!wellFormed0DCase && !wellFormedNDCase)
3312 return emitOpError("rank mismatch");
3313
3314 // Verify all but leading dimension sizes.
3315 for (int64_t r = 1; r < v1Rank; ++r) {
3316 int64_t resDim = resultType.getDimSize(r);
3317 int64_t v1Dim = v1Type.getDimSize(r);
3318 int64_t v2Dim = v2Type.getDimSize(r);
3319 if (resDim != v1Dim || v1Dim != v2Dim)
3320 return emitOpError("dimension mismatch");
3321 }
3322 // Verify mask length.
3323 ArrayRef<int64_t> mask = getMask();
3324 int64_t maskLength = mask.size();
3325 if (maskLength <= 0)
3326 return emitOpError("invalid mask length");
3327 if (maskLength != resultType.getDimSize(0))
3328 return emitOpError("mask length mismatch");
3329 // Verify all indices.
3330 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
3331 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
3332 for (auto [idx, maskPos] : llvm::enumerate(mask)) {
3333 if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize))
3334 return emitOpError("mask index #") << (idx + 1) << " out of range";
3335 }
3336 return success();
3337}
3338
3339LogicalResult
3340ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location> loc,
3341 ShuffleOp::Adaptor adaptor,
3342 SmallVectorImpl<Type> &inferredReturnTypes) {
3343 auto v1Type = llvm::dyn_cast<VectorType>(adaptor.getV1().getType());
3344 if (!v1Type) {
3345 return emitOptionalError(loc, "expected vector type");
3346 }
3347 auto v1Rank = v1Type.getRank();
3348 // Construct resulting type: leading dimension matches mask
3349 // length, all trailing dimensions match the operands.
3350 SmallVector<int64_t, 4> shape;
3351 shape.reserve(v1Rank);
3352 shape.push_back(std::max<size_t>(1, adaptor.getMask().size()));
3353 // In the 0-D case there is no trailing shape to append.
3354 if (v1Rank > 0)
3355 llvm::append_range(shape, v1Type.getShape().drop_front());
3356 inferredReturnTypes.push_back(
3357 VectorType::get(shape, v1Type.getElementType()));
3358 return success();
3359}
3360
3361template <typename T>
3362static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
3363 T expected = begin;
3364 return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
3365 return value == expected++;
3366 });
3367}
3368
/// Identity-mask folds: when the mask reads every leading-dimension element of
/// exactly one operand, in order, the shuffle is that operand.
///
///   shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
///   shuffle V1, V2, [4, 5]       : <4xi32>, <2xi32> -> V2
///
/// Returns a null result when the mask matches neither pattern.
  auto v1Type = op.getV1VectorType();
  auto v2Type = op.getV2VectorType();
  auto mask = op.getMask();
  // getDimSize(0) needs rank >= 1; ShuffleOp::fold returns early for 0-D
  // operands before calling this helper.
  if (isStepIndexArray(mask, 0, v1Type.getDimSize(0)))
    return op.getV1();
  // V2's elements are numbered after V1's, so its identity mask starts at
  // V1's leading size.
  if (isStepIndexArray(mask, v1Type.getDimSize(0), v2Type.getDimSize(0)))
    return op.getV2();
  return {};
}
3381
/// If a shuffle operand is poison, replace all mask indices that reference it
/// with kPoisonIndex. This is an in-place fold: on change it updates the op's
/// mask and returns the op's own result; it returns null when neither operand
/// is poison or no mask entry changed.
  bool isV1Poison = matchPattern(op.getV1(), ub::m_Poison());
  bool isV2Poison = matchPattern(op.getV2(), ub::m_Poison());
  if (!isV1Poison && !isV2Poison)
    return {};

  // Mask entries below v1Size read V1; the remaining entries read V2.
  int64_t v1Size = op.getV1VectorType().getDimSize(0);
  bool changed = false;
  SmallVector<int64_t> newMask = llvm::to_vector(op.getMask());
  for (int64_t &idx : newMask) {
    if (idx == ShuffleOp::kPoisonIndex)
      continue;
    if ((isV1Poison && idx < v1Size) || (isV2Poison && idx >= v1Size)) {
      idx = ShuffleOp::kPoisonIndex;
      changed = true;
    }
  }

  // Report "no fold" rather than returning the op's result unchanged.
  if (!changed)
    return {};

  op.setMask(newMask);
  return op.getResult();
}
3408
/// Fold shuffle poison, poison -> poison.
/// Both operand attributes must be ub.poison; otherwise returns null.
                                         Attribute v1Attr,
                                         Attribute v2Attr) {
  if (matchPattern(v1Attr, ub::m_Poison()) &&
      matchPattern(v2Attr, ub::m_Poison()))
    return ub::PoisonAttr::get(context);
  return {};
}
3418
/// Fold a shuffle of constant 1-D inputs by evaluating the mask.
/// Each operand must be either a DenseElementsAttr or ub.poison; otherwise
/// (or for non-1-D operands) returns null.
                                         Attribute v2Attr) {
  auto v1Type = op.getV1VectorType();
  if (v1Type.getRank() != 1)
    return {};

  bool isV1Poison = matchPattern(v1Attr, ub::m_Poison());
  bool isV2Poison = matchPattern(v2Attr, ub::m_Poison());

  // Poison input attributes need special handling as they are not
  // DenseElementsAttr. If an index is poison, we select the first element of
  // the first non-poison input.
  // V2 is processed before V1 so that, when V1 is not poison, poisonElement
  // ends up as V1's first element. If both inputs were poison, poisonElement
  // would stay null; ShuffleOp::fold tries foldShufflePoisonInputs first, so
  // that case is not expected to reach here.
  SmallVector<Attribute> v1Elements, v2Elements;
  Attribute poisonElement;
  if (!isV2Poison) {
    auto v2DenseAttr = dyn_cast<DenseElementsAttr>(v2Attr);
    if (!v2DenseAttr)
      return {};
    v2Elements = to_vector(v2DenseAttr.getValues<Attribute>());
    poisonElement = v2Elements[0];
  }
  if (!isV1Poison) {
    auto v1DenseAttr = dyn_cast<DenseElementsAttr>(v1Attr);
    if (!v1DenseAttr)
      return {};
    v1Elements = to_vector(v1DenseAttr.getValues<Attribute>());
    poisonElement = v1Elements[0];
  }

  // Evaluate the mask: entries below v1Size index V1, the rest index V2.
  ArrayRef<int64_t> mask = op.getMask();
  SmallVector<Attribute> results;
  int64_t v1Size = v1Type.getDimSize(0);
  for (int64_t maskIdx : mask) {
    Attribute indexedElm;
    // TODO: Return a partial poison vector when supported by the UB dialect.
    if (maskIdx == ShuffleOp::kPoisonIndex) {
      indexedElm = poisonElement;
    } else {
      if (maskIdx < v1Size)
        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
      else
        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
    }

    results.push_back(indexedElm);
  }

  return DenseElementsAttr::get(op.getResultVectorType(), results);
}
3469
3470OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
3471 auto v1Type = getV1VectorType();
3472
3473 assert(!v1Type.isScalable() && !getV2VectorType().isScalable() &&
3474 "Vector shuffle does not support scalable vectors");
3475
3476 // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
3477 // but must be a canonicalization into a vector.broadcast.
3478 if (v1Type.getRank() == 0)
3479 return {};
3480
3481 if (auto res = foldShuffleIdentityMask(*this))
3482 return res;
3483 if (auto res = foldShufflePoisonOperandToMask(*this))
3484 return res;
3485
3486 Attribute v1Attr = adaptor.getV1(), v2Attr = adaptor.getV2();
3487 if (!v1Attr || !v2Attr)
3488 return {};
3489
3490 if (auto res = foldShufflePoisonInputs(getContext(), v1Attr, v2Attr))
3491 return res;
3492 if (auto res = foldShuffleConstantInputs(*this, v1Attr, v2Attr))
3493 return res;
3494
3495 return {};
3496}
3497
3498namespace {
3499
3500// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
3501// to a broadcast.
3502struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3503 using Base::Base;
3504
3505 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
3506 PatternRewriter &rewriter) const override {
3507 VectorType v1VectorType = shuffleOp.getV1VectorType();
3508 ArrayRef<int64_t> mask = shuffleOp.getMask();
3509 if (v1VectorType.getRank() > 0)
3510 return failure();
3511 if (mask.size() != 1)
3512 return failure();
3513 VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
3514 if (mask[0] == 0)
3515 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3516 shuffleOp.getV1());
3517 else
3518 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
3519 shuffleOp.getV2());
3520 return success();
3521 }
3522};
3523
3524/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3525/// vector.broadcast with a scalar operand, return the scalar value that is
3526/// splatted. Otherwise return null.
3527///
3528/// Example:
3529///
3530/// scalar_source --> vector.broadcast --> value - return scalar_source
3531static Value getScalarSplatSource(Value value) {
3532 // Block argument:
3533 Operation *defOp = value.getDefiningOp();
3534 if (!defOp)
3535 return {};
3536
3537 auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3538
3539 // Not broadcast (and not splat):
3540 if (!broadcast)
3541 return {};
3542
3543 // Broadcast of a vector:
3544 if (isa<VectorType>(broadcast.getSourceType()))
3545 return {};
3546
3547 // Broadcast of a scalar:
3548 return broadcast.getSource();
3549}
3550
3551/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
3552class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3553public:
3554 using Base::Base;
3555
3556 LogicalResult matchAndRewrite(ShuffleOp op,
3557 PatternRewriter &rewriter) const override {
3558 Value splat = getScalarSplatSource(op.getV1());
3559 if (!splat || getScalarSplatSource(op.getV2()) != splat)
3560 return failure();
3561
3562 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3563 return success();
3564 }
3565};
3566
3567/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
3568/// vector.interleave.
3569class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
3570public:
3571 using Base::Base;
3572
3573 LogicalResult matchAndRewrite(ShuffleOp op,
3574 PatternRewriter &rewriter) const override {
3575 VectorType resultType = op.getResultVectorType();
3576 if (resultType.isScalable())
3577 return rewriter.notifyMatchFailure(
3578 op, "ShuffleOp can't represent a scalable interleave");
3579
3580 if (resultType.getRank() != 1)
3581 return rewriter.notifyMatchFailure(
3582 op, "ShuffleOp can't represent an n-D interleave");
3583
3584 VectorType sourceType = op.getV1VectorType();
3585 if (sourceType != op.getV2VectorType() ||
3586 sourceType.getNumElements() * 2 != resultType.getNumElements()) {
3587 return rewriter.notifyMatchFailure(
3588 op, "ShuffleOp types don't match an interleave");
3589 }
3590
3591 ArrayRef<int64_t> shuffleMask = op.getMask();
3592 int64_t resultVectorSize = resultType.getNumElements();
3593 for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
3594 int64_t maskValueA = shuffleMask[i * 2];
3595 int64_t maskValueB = shuffleMask[(i * 2) + 1];
3596 if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
3597 return rewriter.notifyMatchFailure(op,
3598 "ShuffleOp mask not interleaving");
3599 }
3600
3601 rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
3602 return success();
3603 }
3604};
3605
3606/// Pattern to replace usused shuffle operands / results with poison.
3607///
3608/// Example Input:
3609/// %r = vector.shuffle %v1, %v2 [2, 3, 3, 3] : vector<2xi32>, vector<2xi32>
3610///
3611/// Example Output:
3612/// %0 = ub.poison : vector<2xi32>
3613/// %r = vector.shuffle %0, %v2 [2, 3, 3, 3] : vector<2xi32>, vector<2xi32>
3614class FoldUnusedShuffleOperand final : public OpRewritePattern<ShuffleOp> {
3615public:
3616 using Base::Base;
3617
3618 LogicalResult matchAndRewrite(ShuffleOp op,
3619 PatternRewriter &rewriter) const override {
3620 // Replace with poison if all mask elements are poison.
3621 if (llvm::all_of(op.getMask(), [](int64_t mask) {
3622 return mask == ShuffleOp::kPoisonIndex;
3623 })) {
3624 rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType());
3625 return success();
3626 }
3627
3628 // Helper function to replace an operand with poison.
3629 auto replaceOperandWithPoison = [&](OpOperand &operand) {
3630 // Do not replace if the operand is already poison.
3631 if (!matchPattern(operand.get(), ub::m_Poison())) {
3632 Value poison = ub::PoisonOp::create(rewriter, op.getLoc(),
3633 operand.get().getType());
3634 rewriter.modifyOpInPlace(op, [&]() { operand.set(poison); });
3635 return success();
3636 }
3637 return failure();
3638 };
3639
3640 // Replace V1 with poison if it is not used.
3641 int64_t leadingV1Size = op.getV1VectorType().getRank() > 0
3642 ? op.getV1VectorType().getDimSize(0)
3643 : 1;
3644 bool isV1Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3645 return mask != ShuffleOp::kPoisonIndex && mask < leadingV1Size;
3646 });
3647 if (!isV1Used && succeeded(replaceOperandWithPoison(op.getV1Mutable())))
3648 return success();
3649
3650 // Replace V2 with poison if it is not used.
3651 bool isV2Used = llvm::any_of(op.getMask(), [&](int64_t mask) {
3652 return mask != ShuffleOp::kPoisonIndex && mask >= leadingV1Size;
3653 });
3654 if (!isV2Used && succeeded(replaceOperandWithPoison(op.getV2Mutable())))
3655 return success();
3656
3657 return failure();
3658 }
3659};
3660} // namespace
3661
3662void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
3663 MLIRContext *context) {
3664 results.add<ShuffleSplat, ShuffleInterleave, Canonicalize0DShuffleOp,
3665 FoldUnusedShuffleOperand>(context);
3666}
3667
3668//===----------------------------------------------------------------------===//
3669// InsertOp
3670//===----------------------------------------------------------------------===//
3671
3672void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3673 SetIntRangeFn setResultRanges) {
3674 setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3675}
3676
3677void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3678 Value source, Value dest) {
3679 auto vectorTy = cast<VectorType>(dest.getType());
3680 build(builder, result, source, dest,
3681 SmallVector<int64_t>(vectorTy.getRank(), 0));
3682}
3683
3684void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3685 Value source, Value dest, int64_t position) {
3686 build(builder, result, source, dest, ArrayRef<int64_t>{position});
3687}
3688
3689void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3690 Value source, Value dest, OpFoldResult position) {
3691 build(builder, result, source, dest, ArrayRef<OpFoldResult>{position});
3692}
3693
3694void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3695 Value source, Value dest,
3696 ArrayRef<int64_t> position) {
3697 SmallVector<OpFoldResult> posVals;
3698 posVals.reserve(position.size());
3699 llvm::transform(position, std::back_inserter(posVals),
3700 [&](int64_t pos) { return builder.getI64IntegerAttr(pos); });
3701 build(builder, result, source, dest, posVals);
3702}
3703
3704void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
3705 Value source, Value dest,
3706 ArrayRef<OpFoldResult> position) {
3707 SmallVector<int64_t> staticPos;
3708 SmallVector<Value> dynamicPos;
3709 dispatchIndexOpFoldResults(position, dynamicPos, staticPos);
3710 build(builder, result, source, dest, dynamicPos,
3711 builder.getDenseI64ArrayAttr(staticPos));
3712}
3713
3714LogicalResult InsertOp::verify() {
3715 if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
3716 if (srcTy.getRank() == 0)
3717 return emitError(
3718 "expected a scalar instead of a 0-d vector as the source operand");
3719
3720 SmallVector<OpFoldResult> position = getMixedPosition();
3721 auto destVectorType = getDestVectorType();
3722 if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
3723 return emitOpError(
3724 "expected position attribute of rank no greater than dest vector rank");
3725 auto srcVectorType = llvm::dyn_cast<VectorType>(getValueToStoreType());
3726 if (srcVectorType &&
3727 (static_cast<unsigned>(srcVectorType.getRank()) + position.size() !=
3728 static_cast<unsigned>(destVectorType.getRank())))
3729 return emitOpError("expected position attribute rank + source rank to "
3730 "match dest vector rank");
3731 if (!srcVectorType &&
3732 (position.size() != static_cast<unsigned>(destVectorType.getRank())))
3733 return emitOpError(
3734 "expected position attribute rank to match the dest vector rank");
3735 for (auto [idx, pos] : llvm::enumerate(position)) {
3736 if (auto attr = dyn_cast<Attribute>(pos)) {
3737 int64_t constIdx = cast<IntegerAttr>(attr).getInt();
3738 if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex,
3739 destVectorType.getDimSize(idx))) {
3740 return emitOpError("expected position attribute #")
3741 << (idx + 1)
3742 << " to be a non-negative integer smaller than the "
3743 "corresponding "
3744 "dest vector dimension";
3745 }
3746 }
3747 }
3748 return success();
3749}
3750
3751// Calculate the linearized position of the continuous chunk of elements to
3752// insert, based on the shape of the value to insert and the positions to insert
3753// at.
3754static int64_t calculateInsertPosition(VectorType destTy,
3755 ArrayRef<int64_t> positions) {
3756 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3757 assert(positions.size() <= completePositions.size() &&
3758 "positions size must be less than or equal to destTy rank");
3759 copy(positions, completePositions.begin());
3760 return linearize(completePositions, computeStrides(destTy.getShape()));
3761}
3762
3763namespace {
3764
3765// If insertOp is only inserting unit dimensions it can be transformed to a
3766// broadcast.
3767class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3768public:
3769 using Base::Base;
3770
3771 LogicalResult matchAndRewrite(InsertOp insertOp,
3772 PatternRewriter &rewriter) const override {
3773 auto srcVecType =
3774 llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType());
3775 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3776 srcVecType.getNumElements())
3777 return failure();
3778 rewriter.replaceOpWithNewOp<BroadcastOp>(
3779 insertOp, insertOp.getDestVectorType(), insertOp.getValueToStore());
3780 return success();
3781 }
3782};
3783
3784/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
3785class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3786public:
3787 using Base::Base;
3788
3789 LogicalResult matchAndRewrite(InsertOp op,
3790 PatternRewriter &rewriter) const override {
3791
3792 Value splat = getScalarSplatSource(op.getValueToStore());
3793 if (!splat || getScalarSplatSource(op.getDest()) != splat)
3794 return failure();
3795
3796 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
3797 return success();
3798 }
3799};
3800
/// Pattern to optimize a chain of insertions.
///
/// This pattern identifies chains of vector.insert operations that:
/// 1. Only insert values at static positions.
/// 2. Completely initialize all elements in the resulting vector.
/// 3. All intermediate insert operations have only one use.
///
/// When these conditions are met, the entire chain can be replaced with a
/// single vector.from_elements operation. Vector-valued inserts are unpacked
/// with vector.to_elements to supply the individual elements.
///
/// To keep this pattern simple, and avoid spending too much time on matching
/// fragmented insert chains, this pattern only considers the last insert op in
/// the chain.
///
/// Example transformation:
///   %poison = ub.poison : vector<2xi32>
///   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
///   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
/// ->
///   %result = vector.from_elements %c1, %c2 : vector<2xi32>
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {

    VectorType destTy = op.getDestVectorType();
    if (destTy.isScalable())
      return failure();
    // Ensure this is the trailing vector.insert op in a chain of inserts.
    for (Operation *user : op.getResult().getUsers())
      if (auto insertOp = dyn_cast<InsertOp>(user))
        if (insertOp.getDest() == op.getResult())
          return failure();

    // Walk backwards through the `dest` operands. chainInsertOps is ordered
    // from this (trailing) insert towards the start of the chain.
    // NOTE(review): this bails out on any dynamic-position or multi-use insert
    // anywhere in the chain, even one that precedes inserts that already
    // initialize every element. This is conservative but not required.
    InsertOp currentOp = op;
    SmallVector<InsertOp> chainInsertOps;
    while (currentOp) {
      // Check cond 1: Dynamic position is not supported.
      if (currentOp.hasDynamicPosition())
        return failure();

      chainInsertOps.push_back(currentOp);
      currentOp = currentOp.getDest().getDefiningOp<InsertOp>();
      // Check cond 3: Intermediate inserts have only one use to avoid an
      // explosion of vectors.
      if (currentOp && !currentOp->hasOneUse())
        return failure();
    }

    // Mark destination elements as written, visiting inserts from last to
    // first. Once every element is covered, any earlier inserts are fully
    // overwritten and need not be inspected.
    int64_t vectorSize = destTy.getNumElements();
    int64_t initializedCount = 0;
    SmallVector<bool> initializedDestIdxs(vectorSize, false);
    SmallVector<int64_t> pendingInsertPos;
    SmallVector<int64_t> pendingInsertSize;
    SmallVector<Value> pendingInsertValues;

    for (auto insertOp : chainInsertOps) {
      // This pattern can do nothing with poison index.
      if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
        return failure();

      // Calculate the linearized position for inserting elements.
      int64_t insertBeginPosition =
          calculateInsertPosition(destTy, insertOp.getStaticPosition());

      // The valueToStore operand may be a vector or a scalar. Need to handle
      // both cases.
      int64_t insertSize = 1;
      if (auto srcVectorType =
              llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType()))
        insertSize = srcVectorType.getNumElements();

      assert(insertBeginPosition + insertSize <= vectorSize &&
             "insert would overflow the vector");

      for (auto index : llvm::seq<int64_t>(insertBeginPosition,
                                           insertBeginPosition + insertSize)) {
        if (initializedDestIdxs[index])
          continue;
        initializedDestIdxs[index] = true;
        ++initializedCount;
      }

      // Defer the creation of ops before we can make sure the pattern can
      // succeed.
      pendingInsertPos.push_back(insertBeginPosition);
      pendingInsertSize.push_back(insertSize);
      pendingInsertValues.push_back(insertOp.getValueToStore());

      if (initializedCount == vectorSize)
        break;
    }

    // Check cond 2: all positions must be initialized.
    if (initializedCount != vectorSize)
      return failure();

    // Replay the recorded inserts from first to last (hence llvm::reverse), so
    // that later inserts overwrite elements written by earlier ones.
    SmallVector<Value> elements(vectorSize);
    for (auto [insertBeginPosition, insertSize, valueToStore] :
         llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize,
                                 pendingInsertValues))) {
      auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType());

      if (!srcVectorType) {
        elements[insertBeginPosition] = valueToStore;
        continue;
      }

      Repeated<Type> elementToInsertTypes(insertSize,
                                          srcVectorType.getElementType());
      // Get all elements from the vector in row-major order.
      auto elementsToInsert = vector::ToElementsOp::create(
          rewriter, op.getLoc(), elementToInsertTypes, valueToStore);
      for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) {
        elements[insertBeginPosition + linearIdx] =
            elementsToInsert.getResult(linearIdx);
      }
    }

    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements);
    return success();
  }
};
3925
3926} // namespace
3927
/// Constant-folds an insert of a constant attribute into a DenseElementsAttr
/// destination. The result is a new DenseElementsAttr with the stored
/// element(s) written starting at the linearized insert position.
///
/// Returns a null attribute when folding does not apply:
///   - the position is dynamic or contains the poison index,
///   - the destination attribute is not dense or the source attribute is null,
///   - the destination type is scalable, or
///   - the destination has more than `maxVectorSizeFoldThreshold` elements and
///     the insert result has more than one use.
static Attribute
                                  Attribute dstAttr,
                                  int64_t maxVectorSizeFoldThreshold) {
  if (insertOp.hasDynamicPosition())
    return {};

  auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
  if (!denseDst)
    return {};

  if (!srcAttr) {
    return {};
  }

  VectorType destTy = insertOp.getDestVectorType();
  if (destTy.isScalable())
    return {};

  // Make sure we do not create too many large constants.
  if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
      !insertOp->hasOneUse())
    return {};

  // Bail out on poison indices (kPoisonIndex = -1) to avoid computing an
  // invalid (negative) linearized position which would cause UB below.
  if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex))
    return {};

  // Calculate the linearized position for inserting elements.
  int64_t insertBeginPosition =
      calculateInsertPosition(destTy, insertOp.getStaticPosition());
  SmallVector<Attribute> insertedValues;
  Type destEltType = destTy.getElementType();

  /// Converts attribute to the expected type if there's
  /// a mismatch.
  // A dense source contributes all of its elements; any other attribute is
  // treated as one scalar element.
  // NOTE(review): a vector-typed source that is not a DenseElementsAttr (e.g.
  // a ub.poison attribute) would be inserted here as a single element.
  // Confirm that such attributes cannot reach this point.
  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
    for (auto value : denseSource.getValues<Attribute>())
      insertedValues.push_back(convertNumericAttr(value, destEltType));
  } else {
    insertedValues.push_back(convertNumericAttr(srcAttr, destEltType));
  }

  // Overwrite the contiguous run of destination elements starting at the
  // linearized position.
  auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
  copy(insertedValues, allValues.begin() + insertBeginPosition);
  auto newAttr = DenseElementsAttr::get(destTy, allValues);

  return newAttr;
}
3978
3979/// Folder to replace the `dest` operand of the insert op with the root dest of
3980/// the insert op use chain.
3981static Value foldInsertUseChain(InsertOp insertOp) {
3982 auto destInsert = insertOp.getDest().getDefiningOp<InsertOp>();
3983 if (!destInsert)
3984 return {};
3985
3986 if (insertOp.getMixedPosition() != destInsert.getMixedPosition())
3987 return {};
3988
3989 insertOp.setOperand(1, destInsert.getDest());
3990 return insertOp.getResult();
3991}
3992
3993void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3994 MLIRContext *context) {
3995 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3996 InsertChainFullyInitialized>(context);
3997}
3998
/// Folds vector.insert. The order matters: constant indices are folded in
/// place first, and later folds may depend on that. If none of the later
/// folds apply, the in-place result is returned.
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless this insert's result has a single use (see the hasOneUse check in
  // foldDenseElementsAttrDestInsertOp).
  constexpr int64_t vectorSizeFoldThreshold = 256;
  // Fold "vector.insert %v, %dest [] : vector<2x2xf32> into vector<2x2xf32>" to
  // %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
  // (type mismatch).
  if (getNumIndices() == 0 && getValueToStoreType() == getType())
    return getValueToStore();
  // Fold `arith.constant` indices into the `vector.insert` operation.
  // Do not stop here as this fold may enable subsequent folds that require
  // constant indices.
  SmallVector<Value> operands = {getValueToStore(), getDest()};
  auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);

  // Skip an inner insert that writes the same position (in-place update).
  if (auto res = foldInsertUseChain(*this))
    return res;
  if (auto res = foldPoisonIndexInsertExtractOp(
          getContext(), adaptor.getStaticPosition(), kPoisonIndex))
    return res;
  if (auto res = foldDenseElementsAttrDestInsertOp(
          *this, adaptor.getValueToStore(), adaptor.getDest(),
          vectorSizeFoldThreshold)) {
    return res;
  }

  // No other fold applied; report the in-place index fold, if any.
  return inplaceFolded;
}
4027
4028//===----------------------------------------------------------------------===//
4029// InsertStridedSliceOp
4030//===----------------------------------------------------------------------===//
4031
4032void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4033 Value source, Value dest,
4034 ArrayRef<int64_t> offsets,
4035 ArrayRef<int64_t> strides) {
4036 result.addOperands({source, dest});
4037 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4038 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4039 result.addTypes(dest.getType());
4040 result.addAttribute(InsertStridedSliceOp::getOffsetsAttrName(result.name),
4041 offsetsAttr);
4042 result.addAttribute(InsertStridedSliceOp::getStridesAttrName(result.name),
4043 stridesAttr);
4044}
4045
// Fails, emitting an error on `op`, when `arrayAttr` has more entries than
// `shape` has dimensions; succeeds otherwise.
// TODO: Should be moved to Tablegen ConfinedAttr attributes.
template <typename OpType>
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
                                                        ArrayAttr arrayAttr,
                                                        StringRef attrName) {
  if (arrayAttr.size() > shape.size())
    return op.emitOpError("expected ")
           << attrName << " attribute of rank no greater than vector rank";
  return success();
}
4057
// Checks that every integer in `arrayAttr` lies in [min, max) when `halfOpen`
// is true, or in [min, max] when it is false. At the first value outside the
// admissible interval, emits an error on `op` and returns failure; otherwise
// returns success. The diagnostic always prints a half-open interval, using
// max + 1 as the upper bound in the closed case.
template <typename OpType>
static LogicalResult
                                 int64_t max, StringRef attrName,
                                 bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = llvm::cast<IntegerAttr>(attr).getInt();
    // Turn the closed upper bound into an exclusive one.
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
4077
// Checks, dimension by dimension, that `arrayAttr[i]` lies in [min, shape[i])
// when `halfOpen` is true, or in [min, shape[i]] when it is false. Iteration
// uses llvm::zip_first, so only the first arrayAttr.size() dimensions are
// checked (zip_first expects `arrayAttr` to be no longer than `shape`). At the
// first violation, emits an error naming the dimension and returns failure;
// otherwise returns success.
template <typename OpType>
static LogicalResult
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val = llvm::cast<IntegerAttr>(std::get<0>(attrDimPair)).getInt();
    int64_t max = std::get<1>(attrDimPair);
    // Turn the closed upper bound into an exclusive one.
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
4099
// For each index i covered by all three ranges (llvm::zip stops at the
// shortest), checks that arrayAttr1[i] + arrayAttr2[i] is non-negative and
// below shape[i], or at most shape[i] when `halfOpen` is false. At the first
// violation, emits an error on `op` and returns failure; otherwise returns
// success.
//
// NOTE(review): `min` only appears in the diagnostic text; the lower bound
// actually enforced is 0. InsertStridedSliceOp::verify pads the source shape
// with leading zeros and passes min = 1, so enforcing `min` would reject
// zero-offset padded dimensions. Confirm whether the message or the parameter
// should change.
template <typename OpType>
    OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
    ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
    bool halfOpen = true, int64_t min = 1) {
  assert(arrayAttr1.size() <= shape.size());
  assert(arrayAttr2.size() <= shape.size());
  for (auto [index, it] :
       llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
    auto val1 = llvm::cast<IntegerAttr>(std::get<0>(it)).getInt();
    auto val2 = llvm::cast<IntegerAttr>(std::get<1>(it)).getInt();
    int64_t max = std::get<2>(it);
    // Turn the closed upper bound into an exclusive one.
    if (!halfOpen)
      max += 1;
    if (val1 + val2 < 0 || val1 + val2 >= max)
      return op.emitOpError("expected sum(")
             << attrName1 << ", " << attrName2 << ") dimension " << index
             << " to be confined to [" << min << ", " << max << ")";
  }
  return success();
}
4126
                                  MLIRContext *context) {
  // Wrap each value in a 64-bit signless IntegerAttr, then collect the
  // results into an ArrayAttr in the same order as `values`.
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
}
4134
/// Verifies vector.insert_strided_slice:
///   - offsets has one entry per destination dimension and strides has one
///     entry per source dimension;
///   - the source rank does not exceed the destination rank;
///   - every offset lies within its destination dimension;
///   - every stride equals 1;
///   - offset + source size (with the source shape left-padded by zeros to the
///     destination rank) fits within each destination dimension;
///   - scalable flags of the source match the corresponding trailing
///     destination dims, and scalable dims have equal sizes.
LogicalResult InsertStridedSliceOp::verify() {
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be no greater than destination rank");

  // Align the source shape with the destination by left-padding it with
  // zeros, so that offsets[i] + paddedShape[i] can be checked per dimension.
  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  // Offsets in bounds; strides confined to [1, 1]; offset + padded source size
  // within the destination. The argument list of the third check matches
  // isSumOfIntegerArrayAttrConfinedToShape; confirm against the full source.
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
                                               /*max=*/1, stridesName,
                                               /*halfOpen=*/false)) ||
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  // Source dim idx corresponds to destination dim idx + rankDiff (trailing
  // alignment).
  unsigned rankDiff = destShape.size() - sourceShape.size();
  for (unsigned idx = 0; idx < sourceShape.size(); ++idx) {
    if (sourceVectorType.getScalableDims()[idx] !=
        destVectorType.getScalableDims()[idx + rankDiff]) {
      return emitOpError("mismatching scalable flags (at source vector idx=")
             << idx << ")";
    }
    // A scalable source dim must have the same base size as its destination
    // dim.
    if (sourceVectorType.getScalableDims()[idx]) {
      auto sourceSize = sourceShape[idx];
      auto destSize = destShape[idx + rankDiff];
      if (sourceSize != destSize) {
        return emitOpError("expected size at idx=")
               << idx
               << (" to match the corresponding base size from the input "
                   "vector (")
               << sourceSize << (" vs ") << destSize << (")");
      }
    }
  }

  return success();
}
4190
4191namespace {
4192/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
4193class FoldInsertStridedSliceSplat final
4194 : public OpRewritePattern<InsertStridedSliceOp> {
4195public:
4196 using Base::Base;
4197
4198 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4199 PatternRewriter &rewriter) const override {
4200
4201 auto dst = insertStridedSliceOp.getDest();
4202 auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
4203 if (!splat || getScalarSplatSource(dst) != splat)
4204 return failure();
4205
4206 rewriter.replaceOp(insertStridedSliceOp, dst);
4207 return success();
4208 }
4209};
4210
4211/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
4212/// to dst.
4213class FoldInsertStridedSliceOfExtract final
4214 : public OpRewritePattern<InsertStridedSliceOp> {
4215public:
4216 using Base::Base;
4217
4218 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
4219 PatternRewriter &rewriter) const override {
4220 auto extractStridedSliceOp =
4221 insertStridedSliceOp.getValueToStore()
4222 .getDefiningOp<vector::ExtractStridedSliceOp>();
4223
4224 if (!extractStridedSliceOp)
4225 return failure();
4226
4227 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
4228 return failure();
4229
4230 // Check if have the same strides and offsets.
4231 if (extractStridedSliceOp.getStrides() !=
4232 insertStridedSliceOp.getStrides() ||
4233 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
4234 return failure();
4235
4236 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
4237 return success();
4238 }
4239};
4240
4241// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
4242// ConstantOp.
4243class InsertStridedSliceConstantFolder final
4244 : public OpRewritePattern<InsertStridedSliceOp> {
4245public:
4246 using Base::Base;
4247
4248 // Do not create constants with more than `vectorSizeFoldThreashold` elements,
4249 // unless the source vector constant has a single use.
4250 static constexpr int64_t vectorSizeFoldThreshold = 256;
4251
4252 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
4253 PatternRewriter &rewriter) const override {
4254 // Return if 'InsertOp' operand is not defined by a compatible vector
4255 // ConstantOp.
4256 TypedValue<VectorType> destVector = op.getDest();
4257 Attribute vectorDestCst;
4258 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
4259 return failure();
4260
4261 VectorType destTy = destVector.getType();
4262 if (destTy.isScalable())
4263 return failure();
4264
4265 // Make sure we do not create too many large constants.
4266 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
4267 !destVector.hasOneUse())
4268 return failure();
4269
4270 TypedValue<VectorType> sourceValue = op.getValueToStore();
4271 Attribute sourceCst;
4272 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
4273 return failure();
4274
4275 // TODO: Support poison.
4276 if (matchPattern(vectorDestCst, ub::m_Poison()) ||
4277 matchPattern(sourceCst, ub::m_Poison()))
4278 return failure();
4279
4280 // TODO: Handle non-unit strides when they become available.
4281 if (op.hasNonUnitStrides())
4282 return failure();
4283
4284 VectorType sliceVecTy = sourceValue.getType();
4285 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4286 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
4287 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
4288 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
4289
4290 // Calcualte the destination element indices by enumerating all slice
4291 // positions within the destination and linearizing them. The enumeration
4292 // order is lexicographic which yields a sequence of monotonically
4293 // increasing linearized position indices.
4294 // Because the destination may have higher dimensionality then the slice,
4295 // we keep track of two overlapping sets of positions and offsets.
4296 auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
4297 auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
4298 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
4299 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
4300 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
4301 MutableArrayRef<int64_t> currSlicePosition(
4302 currDestPosition.begin() + rankDifference, currDestPosition.end());
4303 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
4304 offsets.end());
4305 do {
4306 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
4307 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
4308 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
4309 "Invalid slice element");
4310 newValues[linearizedPosition] = *sliceValuesIt;
4311 ++sliceValuesIt;
4312 } while (succeeded(
4313 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
4314
4315 auto newAttr = DenseElementsAttr::get(destTy, newValues);
4316 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
4317 return success();
4318 }
4319};
4320
4321} // namespace
4322
4323void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
4324 RewritePatternSet &results, MLIRContext *context) {
4325 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
4326 InsertStridedSliceConstantFolder>(context);
4327}
4328
4329OpFoldResult InsertStridedSliceOp::fold(FoldAdaptor adaptor) {
4330 if (getSourceVectorType() == getDestVectorType())
4331 return getValueToStore();
4332 return {};
4333}
4334
4335//===----------------------------------------------------------------------===//
4336// OuterProductOp
4337//===----------------------------------------------------------------------===//
4338
4339/// Build an op without mask, use the type of `acc` as the return type.
4340void OuterProductOp::build(OpBuilder &builder, OperationState &result,
4341 Value lhs, Value rhs, Value acc) {
4342 result.addOperands({lhs, rhs, acc});
4343 result.addTypes(acc.getType());
4344}
4345
4346void OuterProductOp::print(OpAsmPrinter &p) {
4347 p << " " << getLhs() << ", " << getRhs();
4348 if (getAcc()) {
4349 p << ", " << getAcc();
4350 p.printOptionalAttrDict((*this)->getAttrs());
4351 }
4352 p << " : " << getLhs().getType() << ", " << getRhs().getType();
4353}
4354
4355ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
4356 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
4357 Type tLHS, tRHS;
4358 if (parser.parseOperandList(operandsInfo) ||
4359 parser.parseOptionalAttrDict(result.attributes) ||
4360 parser.parseColonType(tLHS) || parser.parseComma() ||
4361 parser.parseType(tRHS))
4362 return failure();
4363 if (operandsInfo.size() < 2)
4364 return parser.emitError(parser.getNameLoc(),
4365 "expected at least 2 operands");
4366 VectorType vLHS = llvm::dyn_cast<VectorType>(tLHS);
4367 VectorType vRHS = llvm::dyn_cast<VectorType>(tRHS);
4368 if (!vLHS)
4369 return parser.emitError(parser.getNameLoc(),
4370 "expected vector type for operand #1");
4371
4372 VectorType resType;
4373 if (vRHS) {
4374 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
4375 vRHS.getScalableDims()[0]};
4376 resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
4377 vLHS.getElementType(), scalableDimsRes);
4378 } else {
4379 // Scalar RHS operand
4380 SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
4381 resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
4382 scalableDimsRes);
4383 }
4384
4385 if (!result.attributes.get(OuterProductOp::getKindAttrName(result.name))) {
4386 result.attributes.append(
4387 OuterProductOp::getKindAttrName(result.name),
4388 CombiningKindAttr::get(result.getContext(),
4389 OuterProductOp::getDefaultKind()));
4390 }
4391
4392 return failure(
4393 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
4394 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
4395 (operandsInfo.size() > 2 &&
4396 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
4397 parser.addTypeToList(resType, result.types));
4398}
4399
4400LogicalResult OuterProductOp::verify() {
4401 Type tRHS = getOperandTypeRHS();
4402 VectorType vLHS = getOperandVectorTypeLHS(),
4403 vRHS = llvm::dyn_cast<VectorType>(tRHS),
4404 vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
4405
4406 if (vLHS.getRank() != 1)
4407 return emitOpError("expected 1-d vector for operand #1");
4408
4409 if (vRHS) {
4410 // Proper OUTER operation.
4411 if (vRHS.getRank() != 1)
4412 return emitOpError("expected 1-d vector for operand #2");
4413 if (vRES.getRank() != 2)
4414 return emitOpError("expected 2-d vector result");
4415 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4416 return emitOpError("expected #1 operand dim to match result dim #1");
4417 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
4418 return emitOpError("expected #2 operand dim to match result dim #2");
4419 if (vLHS.isScalable() && !vRHS.isScalable()) {
4420 // This restriction reflects what's currently supported in terms of
4421 // scalable vectors. However, we could relax this if there's a use case.
4422 return emitOpError(
4423 "expected either both or only #2 operand dim to be scalable");
4424 }
4425 } else {
4426 // An AXPY operation.
4427 if (vRES.getRank() != 1)
4428 return emitOpError("expected 1-d vector result");
4429 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
4430 return emitOpError("expected #1 operand dim to match result dim #1");
4431 }
4432
4433 if (vACC && vACC != vRES)
4434 return emitOpError("expected operand #3 of same type as result type");
4435
4436 if (!getKindAttr()) {
4437 return emitOpError("expected 'kind' attribute of type CombiningKind (e.g. "
4438 "'vector.kind<add>')");
4439 }
4440
4441 // Verify supported combining kind.
4442 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
4443 return emitOpError("unsupported outerproduct type");
4444
4445 return success();
4446}
4447
4448// MaskableOpInterface methods.
4449
4450/// Returns the mask type expected by this operation. Mostly used for
4451/// verification purposes. It requires the operation to be vectorized."
4452Type OuterProductOp::getExpectedMaskType() {
4453 auto vecType = this->getResultVectorType();
4454 return VectorType::get(vecType.getShape(),
4455 IntegerType::get(vecType.getContext(), /*width=*/1),
4456 vecType.getScalableDims());
4457}
4458
4459//===----------------------------------------------------------------------===//
4460// ExtractStridedSliceOp
4461//===----------------------------------------------------------------------===//
4462
4463// Inference works as follows:
4464// 1. Add 'sizes' from prefix of dims in 'offsets'.
4465// 2. Add sizes from 'vectorType' for remaining dims.
4466// Scalable flags are inherited from 'vectorType'.
4467static Type inferStridedSliceOpResultType(VectorType vectorType,
4468 ArrayAttr offsets, ArrayAttr sizes,
4469 ArrayAttr strides) {
4470 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
4472 shape.reserve(vectorType.getRank());
4473 unsigned idx = 0;
4474 for (unsigned e = offsets.size(); idx < e; ++idx)
4475 shape.push_back(llvm::cast<IntegerAttr>(sizes[idx]).getInt());
4476 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
4477 shape.push_back(vectorType.getShape()[idx]);
4478
4479 return VectorType::get(shape, vectorType.getElementType(),
4480 vectorType.getScalableDims());
4481}
4482
4483void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
4484 Value source, ArrayRef<int64_t> offsets,
4485 ArrayRef<int64_t> sizes,
4486 ArrayRef<int64_t> strides) {
4487 result.addOperands(source);
4488 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
4489 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
4490 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
4491 result.addTypes(
4492 inferStridedSliceOpResultType(llvm::cast<VectorType>(source.getType()),
4493 offsetsAttr, sizesAttr, stridesAttr));
4494 result.addAttribute(ExtractStridedSliceOp::getOffsetsAttrName(result.name),
4495 offsetsAttr);
4496 result.addAttribute(ExtractStridedSliceOp::getSizesAttrName(result.name),
4497 sizesAttr);
4498 result.addAttribute(ExtractStridedSliceOp::getStridesAttrName(result.name),
4499 stridesAttr);
4500}
4501
4502LogicalResult ExtractStridedSliceOp::verify() {
4503 auto type = getSourceVectorType();
4504 auto offsets = getOffsetsAttr();
4505 auto sizes = getSizesAttr();
4506 auto strides = getStridesAttr();
4507 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
4508 return emitOpError(
4509 "expected offsets, sizes and strides attributes of same size");
4510
4511 auto shape = type.getShape();
4512 auto offName = getOffsetsAttrName();
4513 auto sizesName = getSizesAttrName();
4514 auto stridesName = getStridesAttrName();
4515 if (failed(
4516 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
4517 failed(
4518 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
4519 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
4520 stridesName)) ||
4521 failed(
4522 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
4523 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
4524 /*halfOpen=*/false,
4525 /*min=*/1)) ||
4526 failed(isIntegerArrayAttrConfinedToRange(*this, strides, /*min=*/1,
4527 /*max=*/1, stridesName,
4528 /*halfOpen=*/false)) ||
4529 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
4530 shape, offName, sizesName,
4531 /*halfOpen=*/false)))
4532 return failure();
4533
4534 auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
4535 offsets, sizes, strides);
4536 if (getResult().getType() != resultType)
4537 return emitOpError("expected result type to be ") << resultType;
4538
4539 for (unsigned idx = 0; idx < sizes.size(); ++idx) {
4540 if (type.getScalableDims()[idx]) {
4541 auto inputDim = type.getShape()[idx];
4542 auto inputSize = llvm::cast<IntegerAttr>(sizes[idx]).getInt();
4543 if (inputDim != inputSize)
4544 return emitOpError("expected size at idx=")
4545 << idx
4546 << (" to match the corresponding base size from the input "
4547 "vector (")
4548 << inputSize << (" vs ") << inputDim << (")");
4549 }
4550 }
4551
4552 return success();
4553}
4554
4555// When the source of ExtractStrided comes from a chain of InsertStrided ops try
4556// to use the source of the InsertStrided ops if we can detect that the
4557// extracted vector is a subset of one of the vector inserted.
4558static LogicalResult
4559foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
4560 // Helper to extract integer out of ArrayAttr.
4561 auto getElement = [](ArrayAttr array, int idx) {
4562 return llvm::cast<IntegerAttr>(array[idx]).getInt();
4563 };
4564 ArrayAttr extractOffsets = op.getOffsets();
4565 ArrayAttr extractStrides = op.getStrides();
4566 ArrayAttr extractSizes = op.getSizes();
4567 auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
4568 while (insertOp) {
4569 if (op.getSourceVectorType().getRank() !=
4570 insertOp.getSourceVectorType().getRank())
4571 return failure();
4572 ArrayAttr insertOffsets = insertOp.getOffsets();
4573 ArrayAttr insertStrides = insertOp.getStrides();
4574 // If the rank of extract is greater than the rank of insert, we are likely
4575 // extracting a partial chunk of the vector inserted.
4576 if (extractOffsets.size() > insertOffsets.size())
4577 return failure();
4578 bool patialoverlap = false;
4579 bool disjoint = false;
4580 SmallVector<int64_t, 4> offsetDiffs;
4581 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
4582 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
4583 return failure();
4584 int64_t start = getElement(insertOffsets, dim);
4585 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
4586 int64_t offset = getElement(extractOffsets, dim);
4587 int64_t size = getElement(extractSizes, dim);
4588 // Check if the start of the extract offset is in the interval inserted.
4589 if (start <= offset && offset < end) {
4590 // If the extract interval overlaps but is not fully included we may
4591 // have a partial overlap that will prevent any folding.
4592 if (offset + size > end)
4593 patialoverlap = true;
4594 offsetDiffs.push_back(offset - start);
4595 continue;
4596 }
4597 disjoint = true;
4598 break;
4599 }
4600 // The extract element chunk is a subset of the insert element.
4601 if (!disjoint && !patialoverlap) {
4602 op.setOperand(insertOp.getValueToStore());
4603 // OpBuilder is only used as a helper to build an I64ArrayAttr.
4604 OpBuilder b(op.getContext());
4605 op.setOffsetsAttr(b.getI64ArrayAttr(offsetDiffs));
4606 return success();
4607 }
4608 // If the chunk extracted is disjoint from the chunk inserted, keep looking
4609 // in the insert chain.
4610 if (disjoint)
4611 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
4612 else {
4613 // The extracted vector partially overlap the inserted vector, we cannot
4614 // fold.
4615 return failure();
4616 }
4617 }
4618 return failure();
4619}
4620
4621// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4622static OpFoldResult
4624 Attribute foldInput) {
4625
4626 auto dense = llvm::dyn_cast_if_present<DenseElementsAttr>(foldInput);
4627 if (!dense)
4628 return {};
4629
4630 // TODO: Handle non-unit strides when they become available.
4631 if (op.hasNonUnitStrides())
4632 return {};
4633
4634 VectorType sourceVecTy = op.getSourceVectorType();
4635 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
4636 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
4637
4638 VectorType sliceVecTy = op.getType();
4639 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
4640 int64_t rank = sliceVecTy.getRank();
4641
4642 // Expand offsets and sizes to match the vector rank.
4643 SmallVector<int64_t, 4> offsets(rank, 0);
4644 copy(getI64SubArray(op.getOffsets()), offsets.begin());
4645
4646 SmallVector<int64_t, 4> sizes(sourceShape);
4647 copy(getI64SubArray(op.getSizes()), sizes.begin());
4648
4649 // Calculate the slice elements by enumerating all slice positions and
4650 // linearizing them. The enumeration order is lexicographic which yields a
4651 // sequence of monotonically increasing linearized position indices.
4652 const auto denseValuesBegin = dense.value_begin<Attribute>();
4653 SmallVector<Attribute> sliceValues;
4654 sliceValues.reserve(sliceVecTy.getNumElements());
4655 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
4656 do {
4657 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
4658 assert(linearizedPosition < sourceVecTy.getNumElements() &&
4659 "Invalid index");
4660 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
4661 } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
4662
4663 assert(static_cast<int64_t>(sliceValues.size()) ==
4664 sliceVecTy.getNumElements() &&
4665 "Invalid number of slice elements");
4666 return DenseElementsAttr::get(sliceVecTy, sliceValues);
4667}
4668
4669OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
4670 if (getSourceVectorType() == getResult().getType())
4671 return getSource();
4672 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
4673 return getResult();
4674
4675 // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
4676 if (auto splat =
4677 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
4678 return DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
4679
4680 // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
4681 return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
4682}
4683
4684void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
4685 populateFromInt64AttrArray(getOffsets(), results);
4686}
4687
4688namespace {
4689
4690// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
4691//
4692// Example:
4693//
4694// %0 = vector.extract_strided_slice %arg0
4695// {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
4696// : vector<4x8x16xf32> to vector<3x4x16xf32>
4697// %1 = vector.extract_strided_slice %0
4698// {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
4699// : vector<3x4x16xf32> to vector<2x2x16xf32>
4700//
4701// to
4702//
4703// %1 = vector.extract_strided_slice %arg0
4704// {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
4705// : vector<4x8x16xf32> to vector<2x2x16xf32>
4706class StridedSliceFolder final
4707 : public OpRewritePattern<ExtractStridedSliceOp> {
4708public:
4709 using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
4710
4711 LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
4712 PatternRewriter &rewriter) const override {
4713 auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
4714 if (!firstOp)
4715 return failure();
4716
4717 if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
4718 return failure();
4719
4720 SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
4721 SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
4722 SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
4723 SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());
4724
4725 unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
4726 SmallVector<int64_t> combinedOffsets(newRank, 0);
4727 SmallVector<int64_t> combinedSizes(newRank);
4728 ArrayRef<int64_t> firstSourceShape =
4729 firstOp.getSourceVectorType().getShape();
4730 for (unsigned i = 0; i < newRank; ++i) {
4731 int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
4732 int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
4733 combinedOffsets[i] = off1 + off2;
4734
4735 if (i < secondSizes.size()) {
4736 combinedSizes[i] = secondSizes[i];
4737 } else if (i < firstSizes.size()) {
4738 combinedSizes[i] = firstSizes[i];
4739 } else {
4740 combinedSizes[i] = firstSourceShape[i];
4741 }
4742 }
4743
4744 SmallVector<int64_t> combinedStrides(newRank, 1);
4745 rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
4746 secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
4747 combinedStrides);
4748 return success();
4749 }
4750};
4751
4752// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4753// CreateMaskOp.
4754//
4755// Example:
4756//
4757// %mask = vector.create_mask %ub : vector<16xi1>
4758// %slice = vector.extract_strided_slice [%offset] [8] [1]
4759//
4760// to
4761//
4762// %new_ub = arith.subi %ub, %offset
4763// %mask = vector.create_mask %new_ub : vector<8xi1>
4764class StridedSliceCreateMaskFolder final
4765 : public OpRewritePattern<ExtractStridedSliceOp> {
4766 using Base::Base;
4767
4768public:
4769 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4770 PatternRewriter &rewriter) const override {
4771 Location loc = extractStridedSliceOp.getLoc();
4772 // Return if 'extractStridedSliceOp' operand is not defined by a
4773 // CreateMaskOp.
4774 auto createMaskOp =
4775 extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
4776 if (!createMaskOp)
4777 return failure();
4778 // Return if 'extractStridedSliceOp' has non-unit strides.
4779 if (extractStridedSliceOp.hasNonUnitStrides())
4780 return failure();
4781 // Gather constant mask dimension sizes.
4782 SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
4783 // Gather strided slice offsets and sizes.
4784 SmallVector<int64_t> sliceOffsets;
4785 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4786 sliceOffsets);
4787 SmallVector<int64_t> sliceSizes;
4788 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4789
4790 // Compute slice of vector mask region.
4791 SmallVector<Value> sliceMaskDimSizes;
4792 sliceMaskDimSizes.reserve(maskDimSizes.size());
4793 // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4794 // only iterate on the leading dim sizes. The tail accounts for the
4795 // remaining dim sizes.
4796 for (auto [maskDimSize, sliceOffset, sliceSize] :
4797 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4798 // No need to clamp on min/max values, because create_mask has clamping
4799 // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4800 // greater than the vector dim size.
4801 IntegerAttr offsetAttr =
4802 rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
4803 Value offset = arith::ConstantOp::create(rewriter, loc, offsetAttr);
4804 Value sliceMaskDimSize =
4805 arith::SubIOp::create(rewriter, loc, maskDimSize, offset);
4806 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4807 }
4808 // Add unchanged dimensions.
4809 llvm::append_range(
4810 sliceMaskDimSizes,
4811 llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
4812 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4813 // region.
4814 rewriter.replaceOpWithNewOp<CreateMaskOp>(
4815 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4816 sliceMaskDimSizes);
4817 return success();
4818 }
4819};
4820
4821// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
4822// ConstantMaskOp.
4823class StridedSliceConstantMaskFolder final
4824 : public OpRewritePattern<ExtractStridedSliceOp> {
4825public:
4826 using Base::Base;
4827
4828 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
4829 PatternRewriter &rewriter) const override {
4830 // Return if 'extractStridedSliceOp' operand is not defined by a
4831 // ConstantMaskOp.
4832 auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
4833 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
4834 if (!constantMaskOp)
4835 return failure();
4836 // Return if 'extractStridedSliceOp' has non-unit strides.
4837 if (extractStridedSliceOp.hasNonUnitStrides())
4838 return failure();
4839 // Gather constant mask dimension sizes.
4840 ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
4841 // Gather strided slice offsets and sizes.
4842 SmallVector<int64_t> sliceOffsets;
4843 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
4844 sliceOffsets);
4845 SmallVector<int64_t> sliceSizes;
4846 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
4847
4848 // Compute slice of vector mask region.
4849 SmallVector<int64_t> sliceMaskDimSizes;
4850 sliceMaskDimSizes.reserve(maskDimSizes.size());
4851 for (auto [maskDimSize, sliceOffset, sliceSize] :
4852 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
4853 int64_t sliceMaskDimSize = std::max(
4854 static_cast<int64_t>(0),
4855 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
4856 sliceMaskDimSizes.push_back(sliceMaskDimSize);
4857 }
4858 // Add unchanged dimensions.
4859 if (sliceMaskDimSizes.size() < maskDimSizes.size())
4860 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
4861 sliceMaskDimSizes.push_back(maskDimSizes[i]);
4862 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
4863 // region is a conjunction of mask dim intervals).
4864 if (llvm::is_contained(sliceMaskDimSizes, 0))
4865 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
4866
4867 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
4868 // region.
4869 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4870 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
4871 sliceMaskDimSizes);
4872 return success();
4873 }
4874};
4875
4876// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
4877// BroadcastOp(ExtractStrideSliceOp).
4878class StridedSliceBroadcast final
4879 : public OpRewritePattern<ExtractStridedSliceOp> {
4880public:
4881 using Base::Base;
4882
4883 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4884 PatternRewriter &rewriter) const override {
4885 auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
4886 if (!broadcast)
4887 return failure();
4888 auto srcVecType =
4889 llvm::dyn_cast<VectorType>(broadcast.getSource().getType());
4890 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
4891 auto dstVecType = llvm::cast<VectorType>(op.getType());
4892 unsigned dstRank = dstVecType.getRank();
4893 unsigned rankDiff = dstRank - srcRank;
4894 // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
4895 // (n -> m with n > m). If they are originally both broadcasted *and*
4896 // sliced, this can be simplified to just broadcasting.
4897 bool needsSlice = false;
4898 for (unsigned i = 0; i < srcRank; i++) {
4899 if (srcVecType.getDimSize(i) != 1 &&
4900 srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
4901 needsSlice = true;
4902 break;
4903 }
4904 }
4905 Value source = broadcast.getSource();
4906 if (needsSlice) {
4907 SmallVector<int64_t> offsets =
4908 getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
4909 SmallVector<int64_t> sizes =
4910 getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
4911 for (unsigned i = 0; i < srcRank; i++) {
4912 if (srcVecType.getDimSize(i) == 1) {
4913 // In case this dimension was broadcasted *and* sliced, the offset
4914 // and size need to be updated now that there is no broadcast before
4915 // the slice.
4916 offsets[i] = 0;
4917 sizes[i] = 1;
4918 }
4919 }
4920 source = ExtractStridedSliceOp::create(
4921 rewriter, op->getLoc(), source, offsets, sizes,
4922 getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
4923 }
4924 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
4925 return success();
4926 }
4927};
4928
4929/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
4930class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4931public:
4932 using Base::Base;
4933
4934 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4935 PatternRewriter &rewriter) const override {
4936
4937 Value splat = getScalarSplatSource(op.getSource());
4938 if (!splat)
4939 return failure();
4940 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
4941 return success();
4942 }
4943};
4944
4945/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
4946/// slice is contiguous, into extract and shape_cast.
4947///
4948/// Example:
4949/// Before:
4950/// %1 = vector.extract_strided_slice %arg0 {
4951/// offsets = [0, 0, 0, 0, 0],
4952/// sizes = [1, 1, 1, 1, 8],
4953/// strides = [1, 1, 1, 1, 1]
4954/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
4955/// After:
4956/// %0 = vector.extract %arg0[0, 0, 0, 0]
4957/// : vector<8xi8> from vector<8x1x1x2x8xi8>
4958/// %1 = vector.shape_cast %0
4959/// : vector<8xi8> to vector<1x1x1x1x8xi8>
4960///
4961class ContiguousExtractStridedSliceToExtract final
4962 : public OpRewritePattern<ExtractStridedSliceOp> {
4963public:
4964 using Base::Base;
4965
4966 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
4967 PatternRewriter &rewriter) const override {
4968 if (op.hasNonUnitStrides())
4969 return failure();
4970 Value source = op.getOperand();
4971 auto sourceType = cast<VectorType>(source.getType());
4972 if (sourceType.isScalable() || sourceType.getRank() == 0)
4973 return failure();
4974
4975 // Compute the number of offsets to pass to ExtractOp::build. That is the
4976 // difference between the source rank and the desired slice rank. We walk
4977 // the dimensions from innermost out, and stop when the next slice dimension
4978 // is not full-size.
4979 SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
4980 int numOffsets;
4981 for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
4982 if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
4983 break;
4984 }
4985
4986 // If the created extract op would have no offsets, then this whole
4987 // extract_strided_slice is the identity and should have been handled by
4988 // other canonicalizations.
4989 if (numOffsets == 0)
4990 return failure();
4991
4992 // If not even the inner-most dimension is full-size, this op can't be
4993 // rewritten as an ExtractOp.
4994 if (numOffsets == sourceType.getRank() &&
4995 static_cast<int>(sizes.size()) == sourceType.getRank())
4996 return failure();
4997
4998 // The outer dimensions must have unit size.
4999 for (int i = 0; i < numOffsets; ++i) {
5000 if (sizes[i] != 1)
5001 return failure();
5002 }
5003
5004 // Avoid generating slices that have leading unit dimensions. The shape_cast
5005 // op that we create below would take bad generic fallback patterns
5006 // (ShapeCastOpRewritePattern).
5007 while (numOffsets < static_cast<int>(sizes.size()) - 1 &&
5008 sizes[numOffsets] == 1) {
5009 ++numOffsets;
5010 }
5011
5012 SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
5013 auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
5014 Value extract = vector::ExtractOp::create(rewriter, op->getLoc(), source,
5015 extractOffsets);
5016 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
5017 return success();
5018 }
5019};
5020
5021} // namespace
5022
5023void ExtractStridedSliceOp::getCanonicalizationPatterns(
5024 RewritePatternSet &results, MLIRContext *context) {
5025 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
5026 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
5027 results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
5028 StridedSliceConstantMaskFolder, StridedSliceBroadcast,
5029 StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
5030 context);
5031}
5032
5033//===----------------------------------------------------------------------===//
5034// TransferReadOp
5035//===----------------------------------------------------------------------===//
5036
5037/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
5038/// If `padding` is null, a poison value is used.
5039void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5040 VectorType vectorType, Value source,
5041 ValueRange indices, std::optional<Value> padding,
5042 AffineMapAttr permutationMapAttr,
5043 /*optional*/ ArrayAttr inBoundsAttr) {
5044
5045 Type elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5046 if (!padding)
5047 padding = ub::PoisonOp::create(builder, result.location, elemType);
5048 // Delegate to the most general builder (see
5049 // `mlir/Dialect/Vector/IR/VectorOps.cpp.inc`)
5050 build(builder, result, vectorType, source, indices, permutationMapAttr,
5051 *padding, /*mask=*/Value(), inBoundsAttr);
5052}
5053
5054/// 2. Builder that sets padding to zero and an empty mask (variant without
5055/// attrs).
5056/// If `padding` is null, a poison value is used.
5057/// If `permutationMap` is null, a minor identity map is used.
5058/// If `inBounds` is null, an empty mask is used.
5059void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5060 VectorType vectorType, Value source,
5061 ValueRange indices, std::optional<Value> padding,
5062 AffineMap permutationMap,
5063 std::optional<ArrayRef<bool>> inBounds) {
5064 if (!permutationMap)
5065 permutationMap = getTransferMinorIdentityMap(
5066 llvm::cast<ShapedType>(source.getType()), vectorType);
5067 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5068 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
5069 ? builder.getBoolArrayAttr(inBounds.value())
5070 : builder.getBoolArrayAttr(
5071 SmallVector<bool>(vectorType.getRank(), false));
5072 // Delegate to Builder 1
5073 build(builder, result, vectorType, source, indices, padding,
5074 permutationMapAttr, inBoundsAttr);
5075}
5076
5077/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
5078/// If `padding` is null, a poison value is used.
5079/// If `inBounds` is null, an empty mask is used.
5080void TransferReadOp::build(OpBuilder &builder, OperationState &result,
5081 VectorType vectorType, Value source,
5082 ValueRange indices, std::optional<Value> padding,
5083 std::optional<ArrayRef<bool>> inBounds) {
5084 // Delegate to Builder 2
5085 build(builder, result, vectorType, source, indices, padding,
5086 /*permutationMap=*/AffineMap(), inBounds);
5087}
5088
5089template <typename EmitFun>
5090static LogicalResult verifyPermutationMap(AffineMap permutationMap,
5091 EmitFun emitOpError) {
5092 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
5093 for (auto expr : permutationMap.getResults()) {
5094 auto dim = dyn_cast<AffineDimExpr>(expr);
5095 auto zero = dyn_cast<AffineConstantExpr>(expr);
5096 if (zero) {
5097 if (zero.getValue() != 0) {
5098 return emitOpError(
5099 "requires a projected permutation_map (at most one dim or the zero "
5100 "constant can appear in each result)");
5101 }
5102 continue;
5103 }
5104 if (!dim) {
5105 return emitOpError("requires a projected permutation_map (at most one "
5106 "dim or the zero constant can appear in each result)");
5107 }
5108 if (seen[dim.getPosition()]) {
5109 return emitOpError(
5110 "requires a permutation_map that is a permutation (found one dim "
5111 "used more than once)");
5112 }
5113 seen[dim.getPosition()] = true;
5114 }
5115 return success();
5116}
5117
5118static LogicalResult
5119verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
5120 VectorType vectorType, VectorType maskType,
5121 VectorType inferredMaskType, AffineMap permutationMap,
5122 ArrayAttr inBounds) {
5123 if (op->hasAttr("masked")) {
5124 return op->emitOpError("masked attribute has been removed. "
5125 "Use in_bounds instead.");
5126 }
5127
5128 if (!llvm::isa<MemRefType, RankedTensorType>(shapedType))
5129 return op->emitOpError(
5130 "requires source to be a memref or ranked tensor type");
5131
5132 auto elementType = shapedType.getElementType();
5133 DataLayout dataLayout = DataLayout::closest(op);
5134 if (auto vectorElementType = llvm::dyn_cast<VectorType>(elementType)) {
5135 // Memref or tensor has vector element type.
5136 unsigned sourceVecSize =
5137 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
5138 vectorElementType.getShape().back();
5139 unsigned resultVecSize =
5140 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
5141 vectorType.getShape().back();
5142 if (resultVecSize % sourceVecSize != 0)
5143 return op->emitOpError(
5144 "requires the bitwidth of the minor 1-D vector to be an integral "
5145 "multiple of the bitwidth of the minor 1-D vector of the source");
5146
5147 unsigned sourceVecEltRank = vectorElementType.getRank();
5148 unsigned resultVecRank = vectorType.getRank();
5149 if (sourceVecEltRank > resultVecRank)
5150 return op->emitOpError(
5151 "requires source vector element and vector result ranks to match.");
5152 unsigned rankOffset = resultVecRank - sourceVecEltRank;
5153 // Check that permutation map results match 'rankOffset' of vector type.
5154 if (permutationMap.getNumResults() != rankOffset)
5155 return op->emitOpError("requires a permutation_map with result dims of "
5156 "the same rank as the vector type");
5157
5158 if (maskType)
5159 return op->emitOpError("does not support masks with vector element type");
5160 } else {
5161 // Memref or tensor has scalar element type.
5162 unsigned minorSize =
5163 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
5164 unsigned resultVecSize =
5165 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
5166 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
5167 return op->emitOpError(
5168 "requires the bitwidth of the minor 1-D vector to be an integral "
5169 "multiple of the bitwidth of the source element type");
5170
5171 // Check that permutation map results match rank of vector type.
5172 if (permutationMap.getNumResults() != vectorType.getRank())
5173 return op->emitOpError("requires a permutation_map with result dims of "
5174 "the same rank as the vector type");
5175 }
5176
5177 if (permutationMap.getNumSymbols() != 0)
5178 return op->emitOpError("requires permutation_map without symbols");
5179
5180 if (permutationMap.getNumInputs() != shapedType.getRank())
5181 return op->emitOpError("requires a permutation_map with input dims of the "
5182 "same rank as the source type");
5183
5184 if (maskType && maskType != inferredMaskType)
5185 return op->emitOpError("inferred mask type (")
5186 << inferredMaskType << ") and mask operand type (" << maskType
5187 << ") don't match";
5188
5189 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
5190 return op->emitOpError("expects the in_bounds attr of same rank "
5191 "as permutation_map results: ")
5192 << AffineMapAttr::get(permutationMap)
5193 << " vs inBounds of size: " << inBounds.size();
5194
5195 return success();
5196}
5197
5198static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
5199 SmallVector<StringRef, 3> elidedAttrs;
5200 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
5201 if (op.getPermutationMap().isMinorIdentity())
5202 elidedAttrs.push_back(op.getPermutationMapAttrName());
5203 // Elide in_bounds attribute if all dims are out-of-bounds.
5204 if (llvm::none_of(op.getInBoundsValues(), [](bool b) { return b; }))
5205 elidedAttrs.push_back(op.getInBoundsAttrName());
5206 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
5207}
5208
5209void TransferReadOp::print(OpAsmPrinter &p) {
5210 p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
5211 if (getMask())
5212 p << ", " << getMask();
5213 printTransferAttrs(p, *this);
5214 p << " : " << getShapedType() << ", " << getVectorType();
5215}
5216
5217VectorType mlir::vector::inferTransferOpMaskType(VectorType vecType,
5218 AffineMap permMap) {
5219 auto i1Type = IntegerType::get(permMap.getContext(), 1);
5220 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
5221 assert(invPermMap && "Inversed permutation map couldn't be computed");
5222 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
5223
5224 // The MaskOp specification doesn't support 0-D vectors at the moment. Turn a
5225 // 0-D mask into a single-element 1-D mask.
5226 if (maskShape.empty())
5227 maskShape.push_back(1);
5228
5229 SmallVector<bool> scalableDims =
5230 applyPermutationMap(invPermMap, vecType.getScalableDims());
5231
5232 return VectorType::get(maskShape, i1Type, scalableDims);
5233}
5234
5235ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
5236 auto &builder = parser.getBuilder();
5237 SMLoc typesLoc;
5243 // Parsing with support for paddingValue.
5244 if (parser.parseOperand(sourceInfo) ||
5246 parser.parseComma() || parser.parseOperand(paddingInfo))
5247 return failure();
5248 ParseResult hasMask = parser.parseOptionalComma();
5249 if (hasMask.succeeded()) {
5250 if (parser.parseOperand(maskInfo))
5251 return failure();
5252 }
5253 if (parser.parseOptionalAttrDict(result.attributes) ||
5254 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5255 return failure();
5256 if (types.size() != 2)
5257 return parser.emitError(typesLoc, "requires two types");
5258 auto indexType = builder.getIndexType();
5259 auto shapedType = llvm::dyn_cast<ShapedType>(types[0]);
5260 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5261 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5262 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
5263 if (!vectorType)
5264 return parser.emitError(typesLoc, "requires vector type");
5265 auto permMapAttrName = TransferReadOp::getPermutationMapAttrName(result.name);
5266 Attribute permMapAttr = result.attributes.get(permMapAttrName);
5267 AffineMap permMap;
5268 if (!permMapAttr) {
5269 if (shapedType.getRank() <
5270 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5271 return parser.emitError(typesLoc,
5272 "expected a custom permutation_map when "
5273 "rank(source) != rank(destination)");
5274 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5275 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5276 } else {
5277 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5278 }
5279 auto inBoundsAttrName = TransferReadOp::getInBoundsAttrName(result.name);
5280 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5281 if (!inBoundsAttr) {
5282 result.addAttribute(inBoundsAttrName,
5283 builder.getBoolArrayAttr(
5284 SmallVector<bool>(permMap.getNumResults(), false)));
5285 }
5286 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5287 parser.resolveOperands(indexInfo, indexType, result.operands) ||
5288 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
5289 result.operands))
5290 return failure();
5291 if (hasMask.succeeded()) {
5292 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5293 return parser.emitError(
5294 maskInfo.location, "does not support masks with vector element type");
5295 if (vectorType.getRank() != permMap.getNumResults()) {
5296 return parser.emitError(typesLoc,
5297 "expected the same rank for the vector and the "
5298 "results of the permutation map");
5299 }
5300 // Instead of adding the mask type as an op type, compute it based on the
5301 // vector type and the permutation map (to keep the type signature small).
5302 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5303 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5304 return failure();
5305 }
5306 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
5307 builder.getDenseI32ArrayAttr(
5308 {1, static_cast<int32_t>(indexInfo.size()), 1,
5309 static_cast<int32_t>(hasMask.succeeded())}));
5310 return parser.addTypeToList(vectorType, result.types);
5311}
5312
5313LogicalResult TransferReadOp::verify() {
5314 // Consistency of elemental types in source and vector.
5315 ShapedType shapedType = getShapedType();
5316 VectorType vectorType = getVectorType();
5317 VectorType maskType = getMaskType();
5318 auto paddingType = getPadding().getType();
5319 auto permutationMap = getPermutationMap();
5320 VectorType inferredMaskType =
5321 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5322 : VectorType();
5323 auto sourceElementType = shapedType.getElementType();
5324
5325 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
5326 return emitOpError("requires ") << shapedType.getRank() << " indices";
5327
5328 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5329 shapedType, vectorType, maskType,
5330 inferredMaskType, permutationMap, getInBounds())))
5331 return failure();
5332
5333 if (auto sourceVectorElementType =
5334 llvm::dyn_cast<VectorType>(sourceElementType)) {
5335 // Source has vector element type.
5336 // Check that 'sourceVectorElementType' and 'paddingType' types match.
5337 if (sourceVectorElementType != paddingType)
5338 return emitOpError(
5339 "requires source element type and padding type to match.");
5340
5341 } else {
5342 // Check that 'paddingType' is valid to store in a vector type.
5343 if (!VectorType::isValidElementType(paddingType))
5344 return emitOpError("requires valid padding vector elemental type");
5345
5346 // Check that padding type and vector element types match.
5347 if (paddingType != sourceElementType)
5348 return emitOpError(
5349 "requires formal padding and source of the same elemental type");
5350 }
5351
5352 return verifyPermutationMap(permutationMap,
5353 [&](Twine t) { return emitOpError(t); });
5354}
5355
5356// MaskableOpInterface methods.
5357
5358/// Returns the mask type expected by this operation. Mostly used for
5359/// verification purposes. It requires the operation to be vectorized."
5360Type TransferReadOp::getExpectedMaskType() {
5361 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5362}
5363
5364//===----------------------------------------------------------------------===//
5365// TransferReadOp: VectorTransferOpInterface methods.
5366//===----------------------------------------------------------------------===//
5367VectorType TransferReadOp::getVectorType() {
5368 return cast<VectorType>(getVector().getType());
5369}
5370
5371template <typename TransferOp>
5372static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
5373 // TODO: support more aggressive createOrFold on:
5374 // op.getIndices()[indicesIdx] + vectorType < dim(op.getSource(), indicesIdx)
5375 if (op.getShapedType().isDynamicDim(indicesIdx))
5376 return false;
5377 Value index = op.getIndices()[indicesIdx];
5378 std::optional<int64_t> cstOp = getConstantIntValue(index);
5379 if (!cstOp.has_value())
5380 return false;
5381
5382 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
5383 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
5384
5385 return cstOp.value() + vectorSize <= sourceSize;
5386}
5387
5388template <typename TransferOp>
5389static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
5390 // TODO: support 0-d corner case.
5391 // TODO: Be less conservative.
5392 if (op.getTransferRank() == 0)
5393 return failure();
5394 AffineMap permutationMap = op.getPermutationMap();
5395 bool changed = false;
5396 SmallVector<bool, 4> newInBounds;
5397 newInBounds.reserve(op.getTransferRank());
5398 // Idxs of non-bcast dims - used when analysing bcast dims.
5399 SmallVector<unsigned> nonBcastDims;
5400
5401 // 1. Process non-broadcast dims
5402 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
5403 // 1.1. Already marked as in-bounds, nothing to see here.
5404 if (op.isDimInBounds(i)) {
5405 newInBounds.push_back(true);
5406 continue;
5407 }
5408 // 1.2. Currently out-of-bounds, check whether we can statically determine
5409 // it is inBounds.
5410 bool inBounds = false;
5411 auto dimExpr = dyn_cast<AffineDimExpr>(permutationMap.getResult(i));
5412 if (dimExpr) {
5413 inBounds = isInBounds(op, /*resultIdx=*/i,
5414 /*indicesIdx=*/dimExpr.getPosition());
5415 nonBcastDims.push_back(i);
5416 }
5417
5418 newInBounds.push_back(inBounds);
5419 // We commit the pattern if it is "more inbounds".
5420 changed |= inBounds;
5421 }
5422
5423 // 2. Handle broadcast dims
5424 // If all non-broadcast dims are "in bounds", then all bcast dims should be
5425 // "in bounds" as well.
5426 bool allNonBcastDimsInBounds = llvm::all_of(
5427 nonBcastDims, [&newInBounds](unsigned idx) { return newInBounds[idx]; });
5428 if (allNonBcastDimsInBounds) {
5429 for (size_t idx : permutationMap.getBroadcastDims()) {
5430 changed |= !newInBounds[idx];
5431 newInBounds[idx] = true;
5432 }
5433 }
5434
5435 if (!changed)
5436 return failure();
5437 // OpBuilder is only used as a helper to build an I64ArrayAttr.
5438 OpBuilder b(op.getContext());
5439 op.setInBoundsAttr(b.getBoolArrayAttr(newInBounds));
5440 return success();
5441}
5442
5443template <typename TransferOp>
5444static LogicalResult foldTransferFullMask(TransferOp op) {
5445 auto mask = op.getMask();
5446 if (!mask)
5447 return failure();
5448
5450 return failure();
5451
5452 op.getMaskMutable().clear();
5453 return success();
5454}
5455
5456/// When the vector type is `vector<1xT>`, the permutation map is irrelevant:
5457/// the single vector lane always has iteration offset 0, so the element is at
5458/// `indices` regardless of which source dimension the map points at. Replace
5459/// with the minor identity to unblock lowering to vector.load / vector.store.
5460template <typename TransferOp>
5461static LogicalResult foldSize1TransferPermutationMap(TransferOp op) {
5462 VectorType vecType = op.getVectorType();
5463 if (vecType.getRank() != 1 || vecType.getShape()[0] != 1 ||
5464 vecType.isScalable())
5465 return failure();
5466
5467 AffineMap map = op.getPermutationMap();
5468 if (map.isMinorIdentity())
5469 return failure();
5470
5471 int64_t srcRank = op.getShapedType().getRank();
5472 if (srcRank < 1)
5473 return failure();
5474
5475 AffineMap minorIdentity =
5476 AffineMap::getMinorIdentityMap(srcRank, 1, op.getContext());
5477 op.setPermutationMapAttr(AffineMapAttr::get(minorIdentity));
5478 return success();
5479}
5480
5481/// ```
5482/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
5483/// : vector<1x4xf32>, tensor<4x4xf32>
5484/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
5485/// : tensor<4x4xf32>, vector<1x4xf32>
5486/// ```
5487/// -> Folds into
5488/// ```
5489/// %v0
5490/// ```
5491static Value foldRAW(TransferReadOp readOp) {
5492 if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
5493 return {};
5494 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5495 while (defWrite) {
5496 if (checkSameValueRAW(defWrite, readOp))
5497 return defWrite.getVector();
5499 cast<VectorTransferOpInterface>(defWrite.getOperation()),
5500 cast<VectorTransferOpInterface>(readOp.getOperation())))
5501 break;
5502 defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
5503 }
5504 return {};
5505}
5506
5507OpFoldResult TransferReadOp::fold(FoldAdaptor) {
5508 if (Value vec = foldRAW(*this))
5509 return vec;
5510 /// transfer_read(memrefcast) -> transfer_read
5511 if (succeeded(foldTransferInBoundsAttribute(*this)))
5512 return getResult();
5513 if (succeeded(foldTransferFullMask(*this)))
5514 return getResult();
5515 if (succeeded(foldSize1TransferPermutationMap(*this)))
5516 return getResult();
5517 if (succeeded(memref::foldMemRefCast(*this)))
5518 return getResult();
5519 if (succeeded(tensor::foldTensorCast(*this)))
5520 return getResult();
5521 return OpFoldResult();
5522}
5523
5524std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
5525 return llvm::to_vector<4>(getVectorType().getShape());
5526}
5527
5528void TransferReadOp::getEffects(
5529 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5530 &effects) {
5531 if (llvm::isa<MemRefType>(getShapedType()))
5532 effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
5533 SideEffects::DefaultResource::get());
5534}
5535
5536Speculation::Speculatability TransferReadOp::getSpeculatability() {
5537 if (hasPureTensorSemantics())
5540}
5541
5542/// Given a projected permutation, inverse an affine map, making the unused dims
5543/// 0 in the result.
5544static AffineMap inverseWithUnusedDims(AffineMap map) {
5545 assert(map.isProjectedPermutation() &&
5546 "expected a projected permutation map");
5547 SmallVector<AffineExpr> results(map.getNumInputs(),
5549 for (auto [idx, result] : llvm::enumerate(map.getResults())) {
5550 // We should only have dim exprs because this is a projected permutation.
5551 int64_t pos = cast<AffineDimExpr>(result).getPosition();
5552 results[pos] = getAffineDimExpr(idx, map.getContext());
5553 }
5554 return AffineMap::get(/*dimCount=*/map.getNumResults(), /*symbolCount=*/0,
5555 results, map.getContext());
5556}
5557
5558namespace {
5559/// Store to load forwarding for transfer operations with permuation maps.
5560/// Even if the permutation maps are different we can still propagate the store
5561/// into the load if the size of the dimensions read and written match. Then we
5562/// can replace the transfer_read + transfer_write by vector.broadcast and
5563/// vector.transpose.
5564/// Example:
5565/// ```
5566/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
5567/// {in_bounds = [true, true],
5568/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
5569/// vector<4x1xf32>, tensor<4x4x4xf32>
5570/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
5571/// {in_bounds = [true, true, true, true],
5572/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
5573/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
5574/// ```
5575/// To:
5576/// ```
5577/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
5578/// %r = vector.transpose %0, [3, 0, 2, 1] :
5579/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
5580/// ```
5581struct TransferReadAfterWriteToBroadcast
5582 : public OpRewritePattern<TransferReadOp> {
5583 using Base::Base;
5584
5585 LogicalResult matchAndRewrite(TransferReadOp readOp,
5586 PatternRewriter &rewriter) const override {
5587 auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
5588 if (!defWrite)
5589 return failure();
5590 // Bail if we need an alias analysis.
5591 if (!readOp.hasPureTensorSemantics() || !defWrite.hasPureTensorSemantics())
5592 return failure();
5593 // Bail in the masked case (too complex atm and needed to properly account
5594 // for padding).
5595 if (readOp.getMask() || defWrite.getMask())
5596 return failure();
5597 // If indices are not the same a shift may be required, bail.
5598 if (readOp.getIndices() != defWrite.getIndices())
5599 return failure();
5600 // Bail if we need a bounds analysis.
5601 if (readOp.hasOutOfBoundsDim() || defWrite.hasOutOfBoundsDim())
5602 return failure();
5603 // TODO: If the written transfer chunk is a superset of the read transfer
5604 // chunk we could do an extract_strided_slice.
5605 if (readOp.getTransferChunkAccessed() !=
5606 defWrite.getTransferChunkAccessed())
5607 return failure();
5608 // WriteMap: tensor -> w_vec
5609 // ReadMap: tensor -> r_vec
5610 //
5611 // inv(WriteMap): w_vec -> tensor
5612 // inv(WriteMap) o ReadMap: w_vec -> r_vec
5613 AffineMap readMap = readOp.getPermutationMap();
5614 AffineMap writeMap = defWrite.getPermutationMap();
5615 AffineMap invWriteMap = inverseWithUnusedDims(writeMap);
5616 AffineMap composedMap = readMap.compose(invWriteMap);
5617 // If there are any unused dims in the composedMap, we have to drop some
5618 // unit dims from the written vector before we can do transpose(broadcast).
5619 // TODO: Support this case.
5620 if (getUnusedDimsBitVector(composedMap).any())
5621 return failure();
5622 // readVec = transpose(broadcast(writeVec))
5623 //
5624 // Build a transpose permutation for the above transpose operation.
5625 //
5626 // Treat the composed map as having extra leading dimensions which are
5627 // the broadcasted dimensions, and treat the zeros as these new broadcasted
5628 // dimensions.
5629 SmallVector<unsigned> broadcastedDims = composedMap.getBroadcastDims();
5630 int64_t numBroadcastedDims = broadcastedDims.size();
5631 auto invPerm = llvm::to_vector_of<int64_t>(broadcastedDims);
5632 invPerm.resize(composedMap.getNumResults());
5633 for (auto [idx, expr] : llvm::enumerate(composedMap.getResults())) {
5634 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
5635 int64_t effectiveDim = dim.getPosition() + numBroadcastedDims;
5636 invPerm[effectiveDim] = idx;
5637 }
5638 }
5639 // Applying the inverse permutation on the readVecTy will give us the
5640 // broadcast result type.
5641 VectorType readVecTy = readOp.getVectorType();
5642 SmallVector<int64_t> permutation = invertPermutationVector(invPerm);
5643 auto broadcastedVecTy =
5644 VectorType::get(applyPermutation(readVecTy.getShape(), invPerm),
5645 readVecTy.getElementType(),
5646 applyPermutation(readVecTy.getScalableDims(), invPerm));
5647 // Build the transpose(broadcast) transformation.
5648 Value vec = defWrite.getVector();
5649 Location loc = readOp.getLoc();
5650 vec = vector::BroadcastOp::create(rewriter, loc, broadcastedVecTy, vec);
5651 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec, permutation);
5652 return success();
5653 }
5654};
5655} // namespace
5656
5657void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
5658 MLIRContext *context) {
5659 results.add<TransferReadAfterWriteToBroadcast>(context);
5660}
5661
5662FailureOr<std::optional<SmallVector<Value>>>
5663TransferReadOp::bubbleDownCasts(OpBuilder &builder) {
5664 if (!hasPureBufferSemantics())
5665 return failure();
5667 getResult());
5668}
5669
5670//===----------------------------------------------------------------------===//
5671// TransferWriteOp
5672//===----------------------------------------------------------------------===//
5673
5674/// 1. Builder with type inference.
5675void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5676 Value vector, Value dest, ValueRange indices,
5677 AffineMapAttr permutationMapAttr,
5678 /*optional*/ Value mask,
5679 /*optional*/ ArrayAttr inBoundsAttr) {
5680 Type resultType = llvm::dyn_cast<RankedTensorType>(dest.getType());
5681 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
5682 mask, inBoundsAttr);
5683}
5684
5685/// 2. Builder with type inference that sets an empty mask (variant with attrs).
5686void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5687 Value vector, Value dest, ValueRange indices,
5688 AffineMapAttr permutationMapAttr,
5689 /*optional*/ ArrayAttr inBoundsAttr) {
5690 build(builder, result, vector, dest, indices, permutationMapAttr,
5691 /*mask=*/Value(), inBoundsAttr);
5692}
5693
5694/// 3. Builder with type inference that sets an empty mask (variant without
5695/// attrs). If `permutationMap` is null, a minor identity map is used.
5696void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5697 Value vector, Value dest, ValueRange indices,
5698 AffineMap permutationMap,
5699 std::optional<ArrayRef<bool>> inBounds) {
5700 if (!permutationMap)
5701 permutationMap =
5702 getTransferMinorIdentityMap(llvm::cast<ShapedType>(dest.getType()),
5703 llvm::cast<VectorType>(vector.getType()));
5704 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
5705 auto inBoundsAttr =
5706 (inBounds && !inBounds.value().empty())
5707 ? builder.getBoolArrayAttr(inBounds.value())
5708 : builder.getBoolArrayAttr(SmallVector<bool>(
5709 llvm::cast<VectorType>(vector.getType()).getRank(), false));
5710 build(builder, result, vector, dest, indices, permutationMapAttr,
5711 /*mask=*/Value(), inBoundsAttr);
5712}
5713
5714/// 4. Builder with type inference that sets an empty mask and sets permutation
5715/// map to 'getMinorIdentityMap'.
5716void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
5717 Value vector, Value dest, ValueRange indices,
5718 std::optional<ArrayRef<bool>> inBounds) {
5719 build(builder, result, vector, dest, indices, /*permutationMap=*/AffineMap(),
5720 inBounds);
5721}
5722
5723ParseResult TransferWriteOp::parse(OpAsmParser &parser,
5724 OperationState &result) {
5725 auto &builder = parser.getBuilder();
5726 SMLoc typesLoc;
5727 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
5728 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
5729 SmallVector<Type, 2> types;
5730 OpAsmParser::UnresolvedOperand maskInfo;
5731 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
5732 parser.parseOperand(sourceInfo) ||
5733 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
5734 return failure();
5735 ParseResult hasMask = parser.parseOptionalComma();
5736 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
5737 return failure();
5738 if (parser.parseOptionalAttrDict(result.attributes) ||
5739 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
5740 return failure();
5741 if (types.size() != 2)
5742 return parser.emitError(typesLoc, "requires two types");
5743 auto indexType = builder.getIndexType();
5744 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
5745 if (!vectorType)
5746 return parser.emitError(typesLoc, "requires vector type");
5747 ShapedType shapedType = llvm::dyn_cast<ShapedType>(types[1]);
5748 if (!shapedType || !llvm::isa<MemRefType, RankedTensorType>(shapedType))
5749 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
5750 auto permMapAttrName =
5751 TransferWriteOp::getPermutationMapAttrName(result.name);
5752 auto permMapAttr = result.attributes.get(permMapAttrName);
5753 AffineMap permMap;
5754 if (!permMapAttr) {
5755 if (shapedType.getRank() <
5756 getEffectiveVectorRankForXferOp(shapedType, vectorType))
5757 return parser.emitError(typesLoc,
5758 "expected a custom permutation_map when "
5759 "rank(source) != rank(destination)");
5760 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
5761 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
5762 } else {
5763 permMap = llvm::cast<AffineMapAttr>(permMapAttr).getValue();
5764 }
5765 auto inBoundsAttrName = TransferWriteOp::getInBoundsAttrName(result.name);
5766 Attribute inBoundsAttr = result.attributes.get(inBoundsAttrName);
5767 if (!inBoundsAttr) {
5768 result.addAttribute(inBoundsAttrName,
5769 builder.getBoolArrayAttr(
5770 SmallVector<bool>(permMap.getNumResults(), false)));
5771 }
5772 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
5773 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
5774 parser.resolveOperands(indexInfo, indexType, result.operands))
5775 return failure();
5776 if (hasMask.succeeded()) {
5777 if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
5778 return parser.emitError(
5779 maskInfo.location, "does not support masks with vector element type");
5780 if (vectorType.getRank() != permMap.getNumResults()) {
5781 return parser.emitError(typesLoc,
5782 "expected the same rank for the vector and the "
5783 "results of the permutation map");
5784 }
5785 auto maskType = inferTransferOpMaskType(vectorType, permMap);
5786 if (parser.resolveOperand(maskInfo, maskType, result.operands))
5787 return failure();
5788 }
5789 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
5790 builder.getDenseI32ArrayAttr(
5791 {1, 1, static_cast<int32_t>(indexInfo.size()),
5792 static_cast<int32_t>(hasMask.succeeded())}));
5793 return failure(llvm::isa<RankedTensorType>(shapedType) &&
5794 parser.addTypeToList(shapedType, result.types));
5795}
5796
5797void TransferWriteOp::print(OpAsmPrinter &p) {
5798 p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
5799 if (getMask())
5800 p << ", " << getMask();
5801 printTransferAttrs(p, *this);
5802 p << " : " << getVectorType() << ", " << getShapedType();
5803}
5804
5805LogicalResult TransferWriteOp::verify() {
5806 // Consistency of elemental types in shape and vector.
5807 ShapedType shapedType = getShapedType();
5808 VectorType vectorType = getVectorType();
5809 VectorType maskType = getMaskType();
5810 auto permutationMap = getPermutationMap();
5811 VectorType inferredMaskType =
5812 maskType ? inferTransferOpMaskType(vectorType, permutationMap)
5813 : VectorType();
5814
5815 if (llvm::size(getIndices()) != shapedType.getRank())
5816 return emitOpError("requires ") << shapedType.getRank() << " indices";
5817
5818 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
5819 // as the semantics is unclear. This can be revisited later if necessary.
5820 if (hasBroadcastDim())
5821 return emitOpError("should not have broadcast dimensions");
5822
5823 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
5824 shapedType, vectorType, maskType,
5825 inferredMaskType, permutationMap, getInBounds())))
5826 return failure();
5827
5828 return verifyPermutationMap(permutationMap,
5829 [&](Twine t) { return emitOpError(t); });
5830}
5831
5832//===----------------------------------------------------------------------===//
5833// TransferWriteOp: MaskableOpInterface methods.
5834//===----------------------------------------------------------------------===//
5835
5836/// Returns the mask type expected by this operation. Mostly used for
5837/// verification purposes.
5838Type TransferWriteOp::getExpectedMaskType() {
5839 return inferTransferOpMaskType(getVectorType(), getPermutationMap());
5840}
5841
5842//===----------------------------------------------------------------------===//
5843// TransferWriteOp: VectorTransferOpInterface methods.
5844//===----------------------------------------------------------------------===//
5845Value TransferWriteOp::getVector() { return getOperand(0); }
5846VectorType TransferWriteOp::getVectorType() {
5847 return cast<VectorType>(getValueToStore().getType());
5848}
5849
5850//===----------------------------------------------------------------------===//
5851// TransferWriteOp: fold methods.
5852//===----------------------------------------------------------------------===//
5853/// Fold:
5854/// ```
5855/// %t1 = ...
5856/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
5857/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5858/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
5859/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5860/// ```
5861///
5862/// into:
5863///
5864/// ```
5865/// %t0
5866/// ```
5867///
5868/// The producer of t1 may or may not be DCE'd depending on whether it is a
5869/// block argument or has side effects.
5870static LogicalResult foldReadInitWrite(TransferWriteOp write,
5871 ArrayRef<Attribute>,
5872 SmallVectorImpl<OpFoldResult> &results) {
5873 // TODO: support 0-d corner case.
5874 if (write.getTransferRank() == 0)
5875 return failure();
5876 auto rankedTensorType =
5877 llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
5878 // If not operating on tensors, bail.
5879 if (!rankedTensorType)
5880 return failure();
5881 // If no read, bail.
5882 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5883 if (!read)
5884 return failure();
5885 // TODO: support 0-d corner case.
5886 if (read.getTransferRank() == 0)
5887 return failure();
5888 // For now, only accept minor identity. Future: composition is minor identity.
5889 if (!read.getPermutationMap().isMinorIdentity() ||
5890 !write.getPermutationMap().isMinorIdentity())
5891 return failure();
5892 // Bail on mismatching ranks.
5893 if (read.getTransferRank() != write.getTransferRank())
5894 return failure();
5895 // Bail on potential out-of-bounds accesses.
5896 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
5897 return failure();
5898 // Masked transfers have padding/select semantics and are not identity folds.
5899 if (read.getMask() || write.getMask())
5900 return failure();
5901 // Tensor types must be the same.
5902 if (read.getBase().getType() != rankedTensorType)
5903 return failure();
5904 // Vector types must be the same.
5905 if (read.getVectorType() != write.getVectorType())
5906 return failure();
5907 // Vector and Tensor shapes must match.
5908 if (read.getVectorType().getShape() != rankedTensorType.getShape())
5909 return failure();
5910 // If any index is nonzero.
5911 auto isNotConstantZero = [](Value v) {
5912 auto cstOp = getConstantIntValue(v);
5913 return !cstOp.has_value() || cstOp.value() != 0;
5914 };
5915 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
5916 llvm::any_of(write.getIndices(), isNotConstantZero))
5917 return failure();
5918 // Success.
5919 results.push_back(read.getBase());
5920 return success();
5921}
5922
5923static bool checkSameValueWAR(vector::TransferReadOp read,
5924 vector::TransferWriteOp write) {
5925 return read.getBase() == write.getBase() &&
5926 read.getIndices() == write.getIndices() &&
5927 read.getPermutationMap() == write.getPermutationMap() &&
5928 read.getVectorType() == write.getVectorType() && !read.getMask() &&
5929 !write.getMask();
5930}
5931/// Fold transfer_write write after read:
5932/// ```
5933/// %t0 = ...
5934/// %v = vector.transfer_read %t0[%c0...] :
5935/// tensor<static_sizesxf32>, vector<static_sizesxf32>
5936/// %t1 = vector.transfer_write %v, %t0[%c0...] :
5937/// vector<static_sizesxf32>, tensor<static_sizesxf32>
5938/// ```
5939///
5940/// into:
5941///
5942/// ```
5943/// %t0
5944/// ```
5945static LogicalResult foldWAR(TransferWriteOp write,
5946 SmallVectorImpl<OpFoldResult> &results) {
5947 if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
5948 return failure();
5949 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
5950 if (!read)
5951 return failure();
5952
5953 if (!checkSameValueWAR(read, write))
5954 return failure();
5955 results.push_back(read.getBase());
5956 return success();
5957}
5958
5959LogicalResult TransferWriteOp::fold(FoldAdaptor adaptor,
5960 SmallVectorImpl<OpFoldResult> &results) {
5961 if (succeeded(foldReadInitWrite(*this, adaptor.getOperands(), results)))
5962 return success();
5963 if (succeeded(foldWAR(*this, results)))
5964 return success();
5965 if (succeeded(foldTransferInBoundsAttribute(*this)))
5966 return success();
5967 if (succeeded(foldTransferFullMask(*this)))
5968 return success();
5969 if (succeeded(foldSize1TransferPermutationMap(*this)))
5970 return success();
5971 return memref::foldMemRefCast(*this);
5972}
5973
5974//===----------------------------------------------------------------------===//
5975// TransferWriteOp: other methods.
5976//===----------------------------------------------------------------------===//
5977std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
5978 return llvm::to_vector<4>(getVectorType().getShape());
5979}
5980
5981void TransferWriteOp::getEffects(
5982 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5983 &effects) {
5984 if (llvm::isa<MemRefType>(getShapedType()))
5985 effects.emplace_back(MemoryEffects::Write::get(), &getBaseMutable(),
5986 SideEffects::DefaultResource::get());
5987}
5988
/// Reports whether this transfer_write may be speculatively executed. The
/// decision depends on hasPureTensorSemantics().
/// NOTE(review): the statement guarded by the `if` below, and the fallback
/// return, are missing from this copy of the function. Presumably it returns
/// Speculation::Speculatable for pure tensor semantics and
/// Speculation::NotSpeculatable otherwise; confirm against upstream.
Speculation::Speculatability TransferWriteOp::getSpeculatability() {
  if (hasPureTensorSemantics())
}
5995namespace {
/// Remove a dead transfer_write from the SSA chain so that it can be
/// eliminated by DCE.
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///     : vector<1x4xf32>, tensor<4x4xf32>
/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
///     : vector<1x4xf32>, tensor<4x4xf32>
/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///     : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// into:
///
/// ```
/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
///     : vector<1x4xf32>, tensor<4x4xf32>
/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
///     : vector<1x4xf32>, tensor<4x4xf32>
/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
///     : vector<1x4xf32>, tensor<4x4xf32>
/// ```
///
/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
/// any other uses.
///
/// Only ranked-tensor writes are considered. Starting at the write that
/// produces `writeOp`'s base, the pattern walks up the chain of
/// transfer_writes. It only continues past a write that has a single use. When
/// checkSameValueWAW(writeOp, defWrite) holds, the base of the op currently
/// tracked (`writeToModify`) is redirected to `defWrite`'s base.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
      return failure();
    // The op whose base operand gets rewired once a match is found.
    vector::TransferWriteOp writeToModify = writeOp;

    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      // checkSameValueWAW (defined elsewhere) decides whether `writeOp`
      // makes `defWrite` redundant; if so, bypass `defWrite`.
      if (checkSameValueWAW(writeOp, defWrite)) {
        rewriter.modifyOpInPlace(writeToModify, [&]() {
          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
      // NOTE(review): the opening of the condition guarding the `break` below
      // is missing from this copy; only its two interface-cast arguments are
      // visible. Presumably it stops the walk unless `defWrite` and `writeOp`
      // are known to access disjoint regions. Confirm against upstream.
          cast<VectorTransferOpInterface>(defWrite.getOperation()),
          cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use, we can safely
      // look at the previous store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
6051
6052/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
6053/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
6054/// overwritten and inserted into another tensor. After this rewrite, the
6055/// operations bufferize in-place since all of them work on the same slice.
6056///
6057/// For example:
6058/// ```mlir
6059/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
6060/// : vector<8x16xf32>, tensor<8x16xf32>
6061/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
6062/// : tensor<8x16xf32> to tensor<?x?xf32>
6063/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6064/// : tensor<?x?xf32> into tensor<27x37xf32>
6065/// ```
6066/// folds to
6067/// ```mlir
6068/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6069/// : tensor<27x37xf32> to tensor<?x?xf32>
6070/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
6071/// : vector<8x16xf32>, tensor<?x?xf32>
6072/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
6073/// : tensor<?x?xf32> into tensor<27x37xf32>
6074/// ```
6075struct SwapExtractSliceOfTransferWrite
6076 : public OpRewritePattern<tensor::InsertSliceOp> {
6077public:
6078 using Base::Base;
6079
6080 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
6081 PatternRewriter &rewriter) const override {
6082 if (!insertOp.hasUnitStride())
6083 return failure();
6084 auto extractOp =
6085 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
6086 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
6087 return failure();
6088 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
6089 if (!transferOp || !transferOp->hasOneUse())
6090 return failure();
6091
6092 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
6093 // rank-reducing.
6094 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
6095 return rewriter.notifyMatchFailure(insertOp,
6096 "use-def chain is rank-reducing");
6097 }
6098
6099 // Fail if tensor::ExtractSliceOp has non-zero offset.
6100 if (!extractOp.hasZeroOffset()) {
6101 return rewriter.notifyMatchFailure(insertOp,
6102 "ExtractSliceOp has non-zero offset");
6103 }
6104
6105 // Fail if tensor::TransferWriteOp has non-zero offset.
6106 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
6107 return getConstantIntValue(value) == static_cast<int64_t>(0);
6108 })) {
6109 return rewriter.notifyMatchFailure(insertOp,
6110 "TranferWriteOp has non-zero offset");
6111 }
6112
6113 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
6114 if (insertOp.getMixedSizes().size() != extractOp.getMixedSizes().size()) {
6115 return rewriter.notifyMatchFailure(
6116 insertOp, "InsertSliceOp and ExtractSliceOp ranks differ");
6117 }
6118
6119 for (auto [insertSize, extractSize] :
6120 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
6121 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
6122 return rewriter.notifyMatchFailure(
6123 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
6124 }
6125 }
6126
6127 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
6128 assert(transferOp.getVectorType().hasStaticShape() &&
6129 "expected vector to have a static shape");
6130 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
6131 SmallVector<int64_t> resultShape = applyPermutationMap(
6132 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
6133 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
6134 return rewriter.notifyMatchFailure(
6135 insertOp, "TransferWriteOp may not write the full tensor.");
6136 }
6137
6138 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
6139 // Set all in_bounds to false and let the folder infer them.
6140 SmallVector<bool> newInBounds(vectorShape.size(), false);
6141 auto newExtractOp = tensor::ExtractSliceOp::create(
6142 rewriter, extractOp.getLoc(), insertOp.getSourceType(),
6143 insertOp.getDest(), insertOp.getMixedOffsets(),
6144 insertOp.getMixedSizes(), insertOp.getMixedStrides());
6145 auto newTransferWriteOp = TransferWriteOp::create(
6146 rewriter, transferOp.getLoc(), transferOp.getVector(),
6147 newExtractOp.getResult(), transferOp.getIndices(),
6148 transferOp.getPermutationMapAttr(),
6149 rewriter.getBoolArrayAttr(newInBounds));
6150 rewriter.modifyOpInPlace(insertOp, [&]() {
6151 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
6152 });
6153 return success();
6154 }
6155};
6156
6157} // namespace
6158
6159void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
6160 MLIRContext *context) {
6161 results.add<FoldWaw, SwapExtractSliceOfTransferWrite>(context);
6162}
6163
/// Cast-bubbling hook for transfer_write. It bails out with failure() unless
/// the op has pure buffer semantics.
/// NOTE(review): the statement following the guard is truncated in this copy
/// of the source; only its trailing `ValueRange());` is visible. Confirm the
/// delegated helper against upstream.
FailureOr<std::optional<SmallVector<Value>>>
TransferWriteOp::bubbleDownCasts(OpBuilder &builder) {
  if (!hasPureBufferSemantics())
    return failure();
  ValueRange());
}
6171
6172//===----------------------------------------------------------------------===//
6173// LoadOp
6174//===----------------------------------------------------------------------===//
6175
6176static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
6177 VectorType vecTy,
6178 MemRefType memRefTy) {
6179 // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
6180 // need any strides limitations.
6181 if (!vecTy.isScalable() &&
6182 (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
6183 return success();
6184
6185 if (!memRefTy.isLastDimUnitStride())
6186 return op->emitOpError("most minor memref dim must have unit stride");
6187 return success();
6188}
6189
6190LogicalResult vector::LoadOp::verify() {
6191 VectorType resVecTy = getVectorType();
6192 MemRefType memRefTy = getMemRefType();
6193
6194 if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
6195 return failure();
6196
6197 if (memRefTy.getRank() < resVecTy.getRank())
6198 return emitOpError(
6199 "destination memref has lower rank than the result vector");
6200
6201 // Checks for vector memrefs.
6202 Type memElemTy = memRefTy.getElementType();
6203 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6204 if (memVecTy != resVecTy)
6205 return emitOpError("base memref and result vector types should match");
6206 memElemTy = memVecTy.getElementType();
6207 }
6208
6209 if (resVecTy.getElementType() != memElemTy)
6210 return emitOpError("base and result element types should match");
6211 if (llvm::size(getIndices()) != memRefTy.getRank())
6212 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6213 return success();
6214}
6215
6216OpFoldResult LoadOp::fold(FoldAdaptor) {
6217 if (succeeded(memref::foldMemRefCast(*this)))
6218 return getResult();
6219 return OpFoldResult();
6220}
6221
6222std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
6223 return llvm::to_vector<4>(getVectorType().getShape());
6224}
6225
/// Cast-bubbling hook for vector.load.
/// NOTE(review): the body is truncated in this copy of the source; only the
/// trailing `getResult());` argument of the delegated call is visible.
/// Confirm against upstream.
FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  getResult());
}
6231
6232//===----------------------------------------------------------------------===//
6233// StoreOp
6234//===----------------------------------------------------------------------===//
6235
6236LogicalResult vector::StoreOp::verify() {
6237 VectorType valueVecTy = getVectorType();
6238 MemRefType memRefTy = getMemRefType();
6239
6240 if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
6241 return failure();
6242
6243 if (memRefTy.getRank() < valueVecTy.getRank())
6244 return emitOpError("source memref has lower rank than the vector to store");
6245
6246 // Checks for vector memrefs.
6247 Type memElemTy = memRefTy.getElementType();
6248 if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
6249 if (memVecTy != valueVecTy)
6250 return emitOpError(
6251 "base memref and valueToStore vector types should match");
6252 memElemTy = memVecTy.getElementType();
6253 }
6254
6255 if (valueVecTy.getElementType() != memElemTy)
6256 return emitOpError("base and valueToStore element type should match");
6257 if (llvm::size(getIndices()) != memRefTy.getRank())
6258 return emitOpError("requires ") << memRefTy.getRank() << " indices";
6259 return success();
6260}
6261
6262LogicalResult StoreOp::fold(FoldAdaptor adaptor,
6263 SmallVectorImpl<OpFoldResult> &results) {
6264 return memref::foldMemRefCast(*this);
6265}
6266
6267std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
6268 return llvm::to_vector<4>(getVectorType().getShape());
6269}
6270
/// Cast-bubbling hook for vector.store.
/// NOTE(review): the body is truncated in this copy of the source; only the
/// trailing `ValueRange());` argument of the delegated call is visible.
/// Confirm against upstream.
FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  ValueRange());
}
6276
6277//===----------------------------------------------------------------------===//
6278// MaskedLoadOp
6279//===----------------------------------------------------------------------===//
6280
6281LogicalResult MaskedLoadOp::verify() {
6282 VectorType maskVType = getMaskVectorType();
6283 VectorType passVType = getPassThruVectorType();
6284 VectorType resVType = getVectorType();
6285 MemRefType memType = getMemRefType();
6286
6287 if (failed(
6288 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6289 return failure();
6290 if (llvm::size(getIndices()) != memType.getRank())
6291 return emitOpError("requires ") << memType.getRank() << " indices";
6292 if (resVType.getShape() != maskVType.getShape())
6293 return emitOpError("expected result shape to match mask shape");
6294 if (resVType != passVType)
6295 return emitOpError("expected pass_thru of same type as result type");
6296 return success();
6297}
6298
namespace {
/// Canonicalizes vector.maskedload based on the statically known format of its
/// mask, as classified by getMaskFormat. The visible arms:
/// - rewrite to a plain vector.load;
/// - forward the pass-through value;
/// - bail out.
/// NOTE(review): the `case MaskFormat::...:` labels of the switch below are
/// missing from this copy of the source. Presumably:
/// - AllTrue -> vector.load;
/// - AllFalse -> pass_thru;
/// - Unknown -> failure.
/// Confirm against upstream.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
      rewriter.replaceOp(load, load.getPassThru());
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
  }
};
} // namespace
6320
6321void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
6322 MLIRContext *context) {
6323 results.add<MaskedLoadFolder>(context);
6324}
6325
6326OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
6327 if (succeeded(memref::foldMemRefCast(*this)))
6328 return getResult();
6329 return OpFoldResult();
6330}
6331
/// Cast-bubbling hook for vector.maskedload.
/// NOTE(review): the body is truncated in this copy of the source; only the
/// trailing `getResult());` argument of the delegated call is visible.
/// Confirm against upstream.
FailureOr<std::optional<SmallVector<Value>>>
MaskedLoadOp::bubbleDownCasts(OpBuilder &builder) {
  getResult());
}
6337
6338//===----------------------------------------------------------------------===//
6339// MaskedStoreOp
6340//===----------------------------------------------------------------------===//
6341
6342LogicalResult MaskedStoreOp::verify() {
6343 VectorType maskVType = getMaskVectorType();
6344 VectorType valueVType = getVectorType();
6345 MemRefType memType = getMemRefType();
6346
6347 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6348 "valueToStore")))
6349 return failure();
6350 if (llvm::size(getIndices()) != memType.getRank())
6351 return emitOpError("requires ") << memType.getRank() << " indices";
6352 if (valueVType.getShape() != maskVType.getShape())
6353 return emitOpError("expected valueToStore shape to match mask shape");
6354 return success();
6355}
6356
namespace {
/// Canonicalizes vector.maskedstore based on the statically known format of
/// its mask, as classified by getMaskFormat. The visible arms:
/// - rewrite to a plain vector.store;
/// - erase the op;
/// - bail out.
/// NOTE(review): the `case MaskFormat::...:` labels of the switch below are
/// missing from this copy of the source. Presumably:
/// - AllTrue -> vector.store;
/// - AllFalse -> erase;
/// - Unknown -> failure.
/// Confirm against upstream.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
      rewriter.eraseOp(store);
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
  }
};
} // namespace
6378
6379void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
6380 MLIRContext *context) {
6381 results.add<MaskedStoreFolder>(context);
6382}
6383
6384LogicalResult MaskedStoreOp::fold(FoldAdaptor adaptor,
6385 SmallVectorImpl<OpFoldResult> &results) {
6386 return memref::foldMemRefCast(*this);
6387}
6388
/// Cast-bubbling hook for vector.maskedstore.
/// NOTE(review): the body is truncated in this copy of the source; only the
/// trailing `ValueRange());` argument of the delegated call is visible.
/// Confirm against upstream.
FailureOr<std::optional<SmallVector<Value>>>
MaskedStoreOp::bubbleDownCasts(OpBuilder &builder) {
  ValueRange());
}
6394
6395//===----------------------------------------------------------------------===//
6396// GatherOp
6397//===----------------------------------------------------------------------===//
6398
6399LogicalResult GatherOp::verify() {
6400 VectorType indVType = getIndexVectorType();
6401 VectorType maskVType = getMaskVectorType();
6402 VectorType resVType = getVectorType();
6403 ShapedType baseType = getBaseType();
6404
6405 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6406 return emitOpError("requires base to be a memref or ranked tensor type");
6407
6408 if (failed(
6409 verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
6410 return failure();
6411 if (llvm::size(getOffsets()) != baseType.getRank())
6412 return emitOpError("requires ") << baseType.getRank() << " indices";
6413 if (resVType.getShape() != indVType.getShape())
6414 return emitOpError("expected result dim to match indices dim");
6415 if (resVType.getShape() != maskVType.getShape())
6416 return emitOpError("expected result dim to match mask dim");
6417 if (resVType != getPassThruVectorType())
6418 return emitOpError("expected pass_thru of same type as result type");
6419 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6420 return emitOpError(
6421 "alignment is only supported for memref bases, not tensor bases");
6422 }
6423 return success();
6424}
6425
6426// MaskableOpInterface methods.
6427
6428/// Returns the mask type expected by this operation. Mostly used for
6429/// verification purposes. It requires the operation to be vectorized."
6430Type GatherOp::getExpectedMaskType() {
6431 auto vecType = this->getIndexVectorType();
6432 return VectorType::get(vecType.getShape(),
6433 IntegerType::get(vecType.getContext(), /*width=*/1),
6434 vecType.getScalableDims());
6435}
6436
6437std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
6438 return llvm::to_vector<4>(getVectorType().getShape());
6439}
6440
6441/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
6442static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
6443 auto vecType = dyn_cast<VectorType>(indexVec.getType());
6444 if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
6445 return failure();
6446
6447 if (indexVec.getDefiningOp<StepOp>())
6448 return success();
6449
6450 DenseIntElementsAttr elements;
6451 if (!matchPattern(indexVec, m_Constant(&elements)))
6452 return failure();
6453
6454 return success(
6455 llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
6456}
6457
6458namespace {
/// Canonicalizes vector.gather based on the statically known format of its
/// mask, as classified by getMaskFormat. The visible arms:
/// - give up, because a gather has no unmasked equivalent;
/// - forward the pass-through value;
/// - bail out.
/// NOTE(review): the `case MaskFormat::...:` labels of the switch below are
/// missing from this copy of the source. Presumably:
/// - AllTrue -> failure;
/// - AllFalse -> pass_thru;
/// - Unknown -> failure.
/// Confirm against upstream.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
      return failure(); // no unmasked equivalent
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
  }
};
6476
6477/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
6478/// maskedload. Only 1D fixed vectors are supported for now.
6479class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
6480public:
6481 using Base::Base;
6482 LogicalResult matchAndRewrite(GatherOp op,
6483 PatternRewriter &rewriter) const override {
6484 if (!isa<MemRefType>(op.getBase().getType()))
6485 return rewriter.notifyMatchFailure(op, "base must be of memref type");
6486
6487 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6488 return failure();
6489
6490 rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
6491 op.getOffsets(), op.getMask(),
6492 op.getPassThru());
6493 return success();
6494 }
6495};
6496} // namespace
6497
6498void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
6499 MLIRContext *context) {
6500 results.add<GatherFolder, FoldContiguousGather>(context);
6501}
6502
/// Cast-bubbling hook for vector.gather.
/// NOTE(review): the body is truncated in this copy of the source; only the
/// trailing `getResult());` argument of the delegated call is visible.
/// Confirm against upstream.
FailureOr<std::optional<SmallVector<Value>>>
GatherOp::bubbleDownCasts(OpBuilder &builder) {
  getResult());
}
6508
6509//===----------------------------------------------------------------------===//
6510// ScatterOp
6511//===----------------------------------------------------------------------===//
6512
6513LogicalResult ScatterOp::verify() {
6514 VectorType indVType = getIndexVectorType();
6515 VectorType maskVType = getMaskVectorType();
6516 VectorType valueVType = getVectorType();
6517 ShapedType baseType = getBaseType();
6518
6519 if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6520 return emitOpError("requires base to be a memref or ranked tensor type");
6521
6522 if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
6523 "valueToStore")))
6524 return failure();
6525 if (llvm::size(getOffsets()) != baseType.getRank())
6526 return emitOpError("requires ") << baseType.getRank() << " indices";
6527 if (valueVType.getShape() != indVType.getShape())
6528 return emitOpError("expected valueToStore dim to match indices dim");
6529 if (valueVType.getShape() != maskVType.getShape())
6530 return emitOpError("expected valueToStore dim to match mask dim");
6531 if (getAlignmentAttr() && !isa<MemRefType>(baseType)) {
6532 return emitOpError(
6533 "alignment is only supported for memref bases, not tensor bases");
6534 }
6535 return success();
6536}
6537namespace {
/// Canonicalizes vector.scatter (memref or ranked-tensor base) based on the
/// statically known format of its mask, as classified by getMaskFormat. The
/// visible arms:
/// - give up, because a scatter has no unmasked equivalent;
/// - drop the op (erase it for memrefs, forward the base for tensors);
/// - bail out.
/// NOTE(review): the `case MaskFormat::...:` labels of the switch below are
/// missing from this copy of the source. Presumably:
/// - AllTrue -> failure;
/// - AllFalse -> drop;
/// - Unknown -> failure.
/// Confirm against upstream.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using Base::Base;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    ShapedType baseType = scatter.getBaseType();
    bool isMemRef = isa<MemRefType>(baseType);
    if (!isMemRef && !isa<RankedTensorType>(baseType))
      return failure();

    // Memrefs have no result, so an all-false mask can simply erase the op.
    // Tensors carry the updated value, so we must replace uses with the
    // original base tensor instead of erasing.
    switch (getMaskFormat(scatter.getMask())) {
      return failure(); // no unmasked equivalent
      if (isMemRef)
        rewriter.eraseOp(scatter);
      else
        rewriter.replaceOp(scatter, scatter.getBase());
      return success();
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
  }
};
6566
6567/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
6568/// maskedstore. Only 1D fixed vectors are supported for now.
6569class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
6570public:
6571 using Base::Base;
6572 LogicalResult matchAndRewrite(ScatterOp op,
6573 PatternRewriter &rewriter) const override {
6574 // Fold only for memrefs: the replacement uses maskedstore, which does not
6575 // support tensor bases. Tensor cases intentionally bail out.
6576 if (!isa<MemRefType>(op.getBase().getType()))
6577 return failure();
6578
6579 if (failed(isZeroBasedContiguousSeq(op.getIndices())))
6580 return failure();
6581
6582 rewriter.replaceOpWithNewOp<MaskedStoreOp>(
6583 op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore());
6584 return success();
6585 }
6586};
6587} // namespace
6588
6589void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
6590 MLIRContext *context) {
6591 results.add<ScatterFolder, FoldContiguousScatter>(context);
6592}
6593
/// Cast-bubbling hook for vector.scatter.
/// NOTE(review): the body is truncated in this copy of the source; only the
/// trailing `ValueRange());` argument of the delegated call is visible.
/// Confirm against upstream.
FailureOr<std::optional<SmallVector<Value>>>
ScatterOp::bubbleDownCasts(OpBuilder &builder) {
  ValueRange());
}
6599
6600//===----------------------------------------------------------------------===//
6601// ExpandLoadOp
6602//===----------------------------------------------------------------------===//
6603
6604LogicalResult ExpandLoadOp::verify() {
6605 VectorType maskVType = getMaskVectorType();
6606 VectorType passVType = getPassThruVectorType();
6607 VectorType resVType = getVectorType();
6608 MemRefType memType = getMemRefType();
6609
6610 if (failed(
6611 verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
6612 return failure();
6613 if (llvm::size(getIndices()) != memType.getRank())
6614 return emitOpError("requires ") << memType.getRank() << " indices";
6615 if (resVType.getShape() != maskVType.getShape())
6616 return emitOpError("expected result shape to match mask shape");
6617 if (resVType != passVType)
6618 return emitOpError("expected pass_thru of same type as result type");
6619 return success();
6620}
6621
6622namespace {
6623class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
6624public:
6625 using Base::Base;
6626 LogicalResult matchAndRewrite(ExpandLoadOp expand,
6627 PatternRewriter &rewriter) const override {
6628 switch (getMaskFormat(expand.getMask())) {
6630 rewriter.replaceOpWithNewOp<vector::LoadOp>(
6631 expand, expand.getType(), expand.getBase(), expand.getIndices());
6632 return success();
6634 rewriter.replaceOp(expand, expand.getPassThru());
6635 return success();
6637 return failure();
6638 }
6639 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
6640 }
6641};
6642} // namespace
6643
/// Registers the canonicalization patterns for vector.expandload. The only
/// pattern added here is the constant-mask folder `ExpandLoadFolder`.
void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<ExpandLoadFolder>(context);
}
6648
6649FailureOr<std::optional<SmallVector<Value>>>
6650ExpandLoadOp::bubbleDownCasts(OpBuilder &builder) {
6652 getResult());
6653}
6654
6655//===----------------------------------------------------------------------===//
6656// CompressStoreOp
6657//===----------------------------------------------------------------------===//
6658
6659LogicalResult CompressStoreOp::verify() {
6660 VectorType maskVType = getMaskVectorType();
6661 VectorType valueVType = getVectorType();
6662 MemRefType memType = getMemRefType();
6663
6664 if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
6665 "valueToStore")))
6666 return failure();
6667 if (llvm::size(getIndices()) != memType.getRank())
6668 return emitOpError("requires ") << memType.getRank() << " indices";
6669 if (valueVType.getShape() != maskVType.getShape())
6670 return emitOpError("expected valueToStore shape to match mask shape");
6671 return success();
6672}
6673
6674namespace {
6675class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
6676public:
6677 using Base::Base;
6678 LogicalResult matchAndRewrite(CompressStoreOp compress,
6679 PatternRewriter &rewriter) const override {
6680 switch (getMaskFormat(compress.getMask())) {
6682 rewriter.replaceOpWithNewOp<vector::StoreOp>(
6683 compress, compress.getValueToStore(), compress.getBase(),
6684 compress.getIndices());
6685 return success();
6687 rewriter.eraseOp(compress);
6688 return success();
6690 return failure();
6691 }
6692 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
6693 }
6694};
6695} // namespace
6696
/// Registers the canonicalization patterns for vector.compressstore. The only
/// pattern added here is the constant-mask folder `CompressStoreFolder`.
void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<CompressStoreFolder>(context);
}
6701
6702FailureOr<std::optional<SmallVector<Value>>>
6703CompressStoreOp::bubbleDownCasts(OpBuilder &builder) {
6705 ValueRange());
6706}
6707
6708//===----------------------------------------------------------------------===//
6709// ShapeCastOp
6710//===----------------------------------------------------------------------===//
6711
/// Integer range inference for shape_cast. The result's range is taken
/// unchanged from the source operand's range (the only argument range).
void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
6716
6717std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
6718 return llvm::to_vector<4>(getResultVectorType().getShape());
6719}
6720
6721LogicalResult ShapeCastOp::verify() {
6722
6723 VectorType sourceType = getSourceVectorType();
6724 VectorType resultType = getResultVectorType();
6725
6726 // Check that element type is preserved
6727 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
6728 "result")))
6729 return failure();
6730
6731 // Check that number of elements is preserved
6732 int64_t sourceNElms = sourceType.getNumElements();
6733 int64_t resultNElms = resultType.getNumElements();
6734 if (sourceNElms != resultNElms) {
6735 return emitOpError() << "has different number of elements at source ("
6736 << sourceNElms << ") and result (" << resultNElms
6737 << ")";
6738 }
6739
6740 // Check that (non-)scalability is preserved
6741 int64_t sourceNScalableDims = sourceType.getNumScalableDims();
6742 int64_t resultNScalableDims = resultType.getNumScalableDims();
6743 if (sourceNScalableDims != resultNScalableDims)
6744 return emitOpError() << "has different number of scalable dims at source ("
6745 << sourceNScalableDims << ") and result ("
6746 << resultNScalableDims << ")";
6747
6748 return success();
6749}
6750
6751/// Return true if `transpose` does not permute a pair of non-unit dims.
6752/// By `order preserving` we mean that the flattened versions of the input and
6753/// output vectors are (numerically) identical. In other words `transpose` is
6754/// effectively a shape cast.
6755static bool isOrderPreserving(TransposeOp transpose) {
6756 ArrayRef<int64_t> permutation = transpose.getPermutation();
6757 VectorType sourceType = transpose.getSourceVectorType();
6758 ArrayRef<int64_t> inShape = sourceType.getShape();
6759 ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6760 auto isNonScalableUnitDim = [&](int64_t dim) {
6761 return inShape[dim] == 1 && !inDimIsScalable[dim];
6762 };
6763 int64_t current = 0;
6764 for (auto p : permutation) {
6765 if (!isNonScalableUnitDim(p)) {
6766 if (p < current) {
6767 return false;
6768 }
6769 current = p;
6770 }
6771 }
6772 return true;
6773}
6774
/// Folds vector.shape_cast. The cases are tried in this order:
///  - a cast to the source's own type folds to the source;
///  - a preceding shape_cast, or an order-preserving transpose, is bypassed
///    by rewriting this op's operand in place (returning this op's own
///    result, i.e. an in-place fold);
///  - a preceding broadcast whose source already has the result type folds
///    to that broadcast source;
///  - dense constants are reshaped and poison stays poison.
/// Returns a null OpFoldResult when none of these apply.
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {

  VectorType resultType = getType();

  // No-op shape cast.
  if (getSource().getType() == resultType)
    return getSource();

  // shape_cast(shape_cast(x)) -> shape_cast(x)
  // Done in place: the operand is redirected to `x` and this op's own result
  // is returned.
  if (auto precedingShapeCast = getSource().getDefiningOp<ShapeCastOp>()) {
    setOperand(precedingShapeCast.getSource());
    return getResult();
  }

  // shape_cast(transpose(x)) -> shape_cast(x)
  // Only legal when the transpose does not reorder elements (see
  // isOrderPreserving). Otherwise this returns early without trying the
  // remaining folds.
  if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
    if (isOrderPreserving(transpose)) {
      setOperand(transpose.getVector());
      return getResult();
    }
    return {};
  }

  // Y = shape_cast(broadcast(X))
  //      -> X, if X and Y have same type
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == resultType)
      return bcastOp.getSource();
  }

  // shape_cast(constant) -> constant
  // The constant's elements are kept; only its type is reshaped.
  if (auto denseAttr =
          dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
    return denseAttr.reshape(getType());

  // shape_cast(poison) -> poison
  if (matchPattern(adaptor.getSource(), ub::m_Poison()))
    return ub::PoisonAttr::get(getContext());

  return {};
}
6816
6817namespace {
6818
6819/// Helper function that computes a new vector type based on the input vector
6820/// type by removing the trailing one dims:
6821///
6822/// vector<4x1x1xi1> --> vector<4x1xi1>
6823///
6824static VectorType trimTrailingOneDims(VectorType oldType) {
6825 ArrayRef<int64_t> oldShape = oldType.getShape();
6826 ArrayRef<int64_t> newShape = oldShape;
6827
6828 ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
6829 ArrayRef<bool> newScalableDims = oldScalableDims;
6830
6831 while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
6832 newShape = newShape.drop_back(1);
6833 newScalableDims = newScalableDims.drop_back(1);
6834 }
6835
6836 // Make sure we have at least 1 dimension.
6837 // TODO: Add support for 0-D vectors.
6838 if (newShape.empty()) {
6839 newShape = oldShape.take_back();
6840 newScalableDims = oldScalableDims.take_back();
6841 }
6842
6843 return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
6844}
6845
6846/// Folds qualifying shape_cast(create_mask) into a new create_mask
6847///
6848/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
6849/// dimension. If the input vector comes from `vector.create_mask` for which
6850/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6851/// to fold shape_cast into create_mask.
6852///
6853/// BEFORE:
6854/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
6855/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
6856/// AFTER:
6857/// %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
6858class ShapeCastCreateMaskFolderTrailingOneDim final
6859 : public OpRewritePattern<ShapeCastOp> {
6860public:
6861 using Base::Base;
6862
6863 LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
6864 PatternRewriter &rewriter) const override {
6865 Value shapeOpSrc = shapeOp->getOperand(0);
6866 auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
6867 auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
6868 if (!createMaskOp && !constantMaskOp)
6869 return failure();
6870
6871 VectorType shapeOpResTy = shapeOp.getResultVectorType();
6872 VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
6873
6874 VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
6875 if (newVecType != shapeOpResTy)
6876 return failure();
6877
6878 auto numDimsToDrop =
6879 shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
6880
6881 // No unit dims to drop
6882 if (!numDimsToDrop)
6883 return failure();
6884
6885 if (createMaskOp) {
6886 auto maskOperands = createMaskOp.getOperands();
6887 auto numMaskOperands = maskOperands.size();
6888
6889 // Check every mask dim size to see whether it can be dropped
6890 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6891 --i) {
6892 auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
6893 if (!constant || (constant.value() != 1))
6894 return failure();
6895 }
6896 SmallVector<Value> newMaskOperands =
6897 maskOperands.drop_back(numDimsToDrop);
6898
6899 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
6900 newMaskOperands);
6901 return success();
6902 }
6903
6904 if (constantMaskOp) {
6905 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
6906 auto numMaskOperands = maskDimSizes.size();
6907
6908 // Check every mask dim size to see whether it can be dropped
6909 for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
6910 --i) {
6911 if (maskDimSizes[i] != 1)
6912 return failure();
6913 }
6914
6915 auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
6916 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
6917 newMaskOperands);
6918 return success();
6919 }
6920
6921 return failure();
6922 }
6923};
6924
6925// vector.broadcast has two distinct semantic modes: duplication across leading
6926// dimensions, and stretching across inner dimensions. This helper returns the
6927// product of the inner-dimension stretching factors.
6928int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
6929 ArrayRef<int64_t> dstShape) {
6930 int stretchingFactor = 1;
6931 int numLeadingDims = dstShape.size() - srcShape.size();
6932 for (int i = 0, e = srcShape.size(); i < e; i++) {
6933 int64_t dstDim = dstShape[numLeadingDims + i];
6934 if (srcShape[i] == 1 && dstDim != 1) {
6935 stretchingFactor *= dstDim;
6936 }
6937 }
6938 return stretchingFactor;
6939}
6940
6941/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
6942class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
6943public:
6944 using Base::Base;
6945
6946 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
6947 PatternRewriter &rewriter) const override {
6948 auto broadcastOp =
6949 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
6950 if (!broadcastOp)
6951 return failure();
6952
6953 auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
6954 bool srcIsScalar = !srcVectorType;
6955
6956 // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
6957 // Example
6958 // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
6959 // %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
6960 // to
6961 // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
6962 VectorType dstVectorType = shapeCastOp.getResultVectorType();
6963 ArrayRef<int64_t> dstShape = dstVectorType.getShape();
6964 ArrayRef<int64_t> srcShape =
6965 srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
6966 ArrayRef<int64_t> broadcastShape =
6967 broadcastOp.getResultVectorType().getShape();
6968
6969 if (!srcIsScalar) {
6970 if (isBroadcastableTo(srcVectorType, dstVectorType) !=
6971 BroadcastableToResult::Success) {
6972 return failure();
6973 }
6974 // Avoid folding if this would result in switching between the two
6975 // distinct semantic modes of vector.broadcast (duplication vs
6976 // stretching). See https://github.com/llvm/llvm-project/issues/190614.
6977 // This is detected by a change in the stretching factor. However if the
6978 // source has a single element, there is no ambiguity.
6979 if (srcVectorType.getNumElements() != 1) {
6980 if (getBroadcastStretchingFactor(srcShape, dstShape) !=
6981 getBroadcastStretchingFactor(srcShape, broadcastShape)) {
6982 return failure();
6983 }
6984 }
6985 }
6986
6987 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
6988 broadcastOp.getSource());
6989 return success();
6990 }
6991};
6992
6993/// Pattern to rewrite Y = ShapeCast(FromElements(X)) as Y = FromElements(X)
6994///
6995/// BEFORE:
6996/// %1 = vector.from_elements %c1, %c2, %c3 : vector<3xf32>
6997/// %2 = vector.shape_cast %1 : vector<3xf32> to vector<1x3xf32>
6998/// AFTER:
6999/// %2 = vector.from_elements %c1, %c2, %c3 : vector<1x3xf32>
7000///
7001/// Note: this transformation is implemented as an OpRewritePattern, not as a
7002/// fold, because we have to create new op FromElementsOp with updated result
7003/// type. This cannot be done with a fold, because fold cannot create new ops
7004/// and the existing FromElementsOp result type differs from the ShapeCastOp
7005/// result type. Mutating the FromElementsOp (not root op) would violate the
7006/// fold contract and break other users.
7007class FoldShapeCastOfFromElements final : public OpRewritePattern<ShapeCastOp> {
7008public:
7009 using Base::Base;
7010
7011 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
7012 PatternRewriter &rewriter) const override {
7013 auto fromElements = shapeCastOp.getSource().getDefiningOp<FromElementsOp>();
7014 if (!fromElements)
7015 return failure();
7016
7017 rewriter.replaceOpWithNewOp<FromElementsOp>(
7018 shapeCastOp, shapeCastOp.getResultVectorType(),
7019 fromElements.getElements());
7020 return success();
7021 }
7022};
7023
7024} // namespace
7025
7026void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
7027 MLIRContext *context) {
7028 results.add<ShapeCastCreateMaskFolderTrailingOneDim, ShapeCastBroadcastFolder,
7029 FoldShapeCastOfFromElements>(context);
7030}
7031
7032//===----------------------------------------------------------------------===//
7033// VectorBitCastOp
7034//===----------------------------------------------------------------------===//
7035
7036LogicalResult BitCastOp::verify() {
7037 auto sourceVectorType = getSourceVectorType();
7038 auto resultVectorType = getResultVectorType();
7039
7040 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
7041 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
7042 return emitOpError("dimension size mismatch at: ") << i;
7043 }
7044
7045 DataLayout dataLayout = DataLayout::closest(*this);
7046 auto sourceElementBits =
7047 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
7048 auto resultElementBits =
7049 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
7050
7051 if (sourceVectorType.getRank() == 0) {
7052 if (sourceElementBits != resultElementBits)
7053 return emitOpError("source/result bitwidth of the 0-D vector element "
7054 "types must be equal");
7055 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
7056 resultElementBits * resultVectorType.getShape().back()) {
7057 return emitOpError(
7058 "source/result bitwidth of the minor 1-D vectors must be equal");
7059 }
7060
7061 return success();
7062}
7063
/// Folds vector.bitcast.
///  - A bitcast to its own source type folds to the source.
///  - bitcast(bitcast(x)) folds to x when the result type equals x's type;
///    otherwise the operand is rewritten in place to x (in-place fold
///    returning this op's result).
///  - Splat constants are folded in two cases:
///      * f16 -> f32: each 32-bit lane holds the 16-bit pattern twice;
///      * integer -> wider integer whose width is a multiple of the source
///        width: the source bit pattern is repeated to fill each lane.
/// Returns a null OpFoldResult otherwise.
OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Not a round trip: bypass the inner bitcast by folding in place.
    setOperand(otherOp.getSource());
    return getResult();
  }

  // The remaining folds need a constant source.
  Attribute sourceConstant = adaptor.getSource();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = llvm::dyn_cast<DenseFPElementsAttr>(sourceConstant)) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  if (auto intPack = llvm::dyn_cast<DenseIntElementsAttr>(sourceConstant)) {
    if (intPack.isSplat()) {
      auto splat = intPack.getSplatValue<IntegerAttr>();

      // isIntOrFloat() guards getIntOrFloatBitWidth() below.
      if (llvm::isa<IntegerType>(dstElemType) && srcElemType.isIntOrFloat()) {
        uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
        uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();

        // Casting to a larger integer bit width.
        if (dstBitWidth > srcBitWidth && dstBitWidth % srcBitWidth == 0) {
          APInt intBits = splat.getValue().zext(dstBitWidth);

          // Duplicate the lower width element.
          // Each iteration ORs in a copy shifted up by one source element,
          // giving dstBitWidth / srcBitWidth copies in total.
          for (uint64_t i = 0; i < dstBitWidth / srcBitWidth - 1; i++)
            intBits = (intBits << srcBitWidth) | intBits;
          return DenseElementsAttr::get(getResultVectorType(), intBits);
        }
      }
    }
  }

  return {};
}
7125
7126std::optional<SmallVector<int64_t, 4>> BitCastOp::getShapeForUnroll() {
7127 return llvm::to_vector<4>(getResultVectorType().getShape());
7128}
7129
7130//===----------------------------------------------------------------------===//
7131// TypeCastOp
7132//===----------------------------------------------------------------------===//
7133
7134static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
7135 auto vectorType = llvm::dyn_cast<VectorType>(memRefType.getElementType());
7136 SmallVector<int64_t, 8> res(memRefType.getShape());
7137 if (vectorType)
7138 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
7139 return res;
7140}
7141
7142/// Build the canonical memRefType with a single vector.
7143/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
7144void TypeCastOp::build(OpBuilder &builder, OperationState &result,
7145 Value source) {
7146 result.addOperands(source);
7147 MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
7148 VectorType vectorType =
7149 VectorType::get(extractShape(memRefType),
7151 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
7152 memRefType.getMemorySpace()));
7153}
7154
7155LogicalResult TypeCastOp::verify() {
7156 MemRefType canonicalType = getMemRefType().canonicalizeStridedLayout();
7157 if (!canonicalType.getLayout().isIdentity())
7158 return emitOpError("expects operand to be a memref with identity layout");
7159 if (!getResultMemRefType().getLayout().isIdentity())
7160 return emitOpError("expects result to be a memref with identity layout");
7161 if (getResultMemRefType().getMemorySpace() !=
7162 getMemRefType().getMemorySpace())
7163 return emitOpError("expects result in same memory space");
7164
7165 auto sourceType = getMemRefType();
7166 auto resultType = getResultMemRefType();
7167 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
7169 return emitOpError(
7170 "expects result and operand with same underlying scalar type: ")
7171 << resultType;
7172 if (extractShape(sourceType) != extractShape(resultType))
7173 return emitOpError(
7174 "expects concatenated result and operand shapes to be equal: ")
7175 << resultType;
7176 return success();
7177}
7178
7179//===----------------------------------------------------------------------===//
7180// TransposeOp
7181//===----------------------------------------------------------------------===//
7182
7183void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
7184 Value vector, ArrayRef<int64_t> permutation) {
7185 VectorType vt = llvm::cast<VectorType>(vector.getType());
7186 SmallVector<int64_t, 4> transposedShape(vt.getRank());
7187 SmallVector<bool, 4> transposedScalableDims(vt.getRank());
7188 for (unsigned i = 0; i < permutation.size(); ++i) {
7189 transposedShape[i] = vt.getShape()[permutation[i]];
7190 transposedScalableDims[i] = vt.getScalableDims()[permutation[i]];
7191 }
7192
7193 result.addOperands(vector);
7194 result.addTypes(VectorType::get(transposedShape, vt.getElementType(),
7195 transposedScalableDims));
7196 result.addAttribute(TransposeOp::getPermutationAttrName(result.name),
7197 builder.getDenseI64ArrayAttr(permutation));
7198}
7199
7200OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
7201 // Eliminate splat constant transpose ops.
7202 if (auto splat =
7203 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
7204 return splat.reshape(getResultVectorType());
7205
7206 // Eliminate poison transpose ops.
7207 if (matchPattern(adaptor.getVector(), ub::m_Poison()))
7208 return ub::PoisonAttr::get(getContext());
7209
7210 // Eliminate identity transposes, and more generally any transposes that
7211 // preserves the shape without permuting elements.
7212 //
7213 // Examples of what to fold:
7214 // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
7215 // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
7216 // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
7217 //
7218 // Example of what NOT to fold:
7219 // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
7220 //
7221 if (getSourceVectorType() == getResultVectorType() &&
7222 isOrderPreserving(*this))
7223 return getVector();
7224
7225 return {};
7226}
7227
7228LogicalResult vector::TransposeOp::verify() {
7229 VectorType vectorType = getSourceVectorType();
7230 VectorType resultType = getResultVectorType();
7231 int64_t rank = resultType.getRank();
7232 if (vectorType.getRank() != rank)
7233 return emitOpError("vector result rank mismatch: ") << rank;
7234 // Verify transposition array.
7235 ArrayRef<int64_t> perm = getPermutation();
7236 int64_t size = perm.size();
7237 if (rank != size)
7238 return emitOpError("transposition length mismatch: ") << size;
7239 SmallVector<bool, 8> seen(rank, false);
7240 for (const auto &ta : llvm::enumerate(perm)) {
7241 if (ta.value() < 0 || ta.value() >= rank)
7242 return emitOpError("transposition index out of range: ") << ta.value();
7243 if (seen[ta.value()])
7244 return emitOpError("duplicate position index: ") << ta.value();
7245 seen[ta.value()] = true;
7246 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(ta.value()))
7247 return emitOpError("dimension size mismatch at: ") << ta.value();
7248 }
7249 return success();
7250}
7251
7252std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
7253 return llvm::to_vector<4>(getResultVectorType().getShape());
7254}
7255
/// Integer range inference for transpose. The result's range is taken
/// unchanged from the source operand's range (the only argument range).
void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRanges) {
  setResultRanges(getResult(), argRanges.front());
}
7260
7261namespace {
7262
7263// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
7264class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
7265public:
7266 using Base::Base;
7267
7268 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7269 PatternRewriter &rewriter) const override {
7270 // Composes two permutations: result[i] = permutation1[permutation2[i]].
7271 auto composePermutations = [](ArrayRef<int64_t> permutation1,
7272 ArrayRef<int64_t> permutation2) {
7273 SmallVector<int64_t, 4> result;
7274 for (auto index : permutation2)
7275 result.push_back(permutation1[index]);
7276 return result;
7277 };
7278
7279 // Return if the input of 'transposeOp' is not defined by another transpose.
7280 vector::TransposeOp parentTransposeOp =
7281 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
7282 if (!parentTransposeOp)
7283 return failure();
7284
7285 SmallVector<int64_t, 4> permutation = composePermutations(
7286 parentTransposeOp.getPermutation(), transposeOp.getPermutation());
7287 // Replace 'transposeOp' with a new transpose operation.
7288 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
7289 transposeOp, transposeOp.getResult().getType(),
7290 parentTransposeOp.getVector(), permutation);
7291 return success();
7292 }
7293};
7294
7295/// Replace transpose(splat-like(v)) with broadcast(v)
7296class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
7297public:
7298 using Base::Base;
7299
7300 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7301 PatternRewriter &rewriter) const override {
7302 Value splat = getScalarSplatSource(transposeOp.getVector());
7303 if (!splat)
7304 return failure();
7305
7306 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
7307 transposeOp, transposeOp.getResultVectorType(), splat);
7308 return success();
7309 }
7310};
7311
7312/// Folds transpose(create_mask) into a new transposed create_mask.
7313class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
7314public:
7315 using Base::Base;
7316
7317 LogicalResult matchAndRewrite(TransposeOp transpOp,
7318 PatternRewriter &rewriter) const override {
7319 Value transposeSrc = transpOp.getVector();
7320 auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
7321 auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
7322 if (!createMaskOp && !constantMaskOp)
7323 return failure();
7324
7325 // Get the transpose permutation and apply it to the vector.create_mask or
7326 // vector.constant_mask operands.
7327 ArrayRef<int64_t> permutation = transpOp.getPermutation();
7328
7329 if (createMaskOp) {
7330 auto maskOperands = createMaskOp.getOperands();
7331 SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
7332 applyPermutationToVector(newOperands, permutation);
7333
7334 rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
7335 transpOp, transpOp.getResultVectorType(), newOperands);
7336 return success();
7337 }
7338
7339 // ConstantMaskOp case.
7340 auto maskDimSizes = constantMaskOp.getMaskDimSizes();
7341 auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
7342
7343 rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
7344 transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
7345 return success();
7346 }
7347};
7348
7349/// Folds transpose(shape_cast) into a new shape_cast.
7350class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
7351public:
7352 using Base::Base;
7353
7354 LogicalResult matchAndRewrite(TransposeOp transposeOp,
7355 PatternRewriter &rewriter) const override {
7356 auto shapeCastOp =
7357 transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
7358 if (!shapeCastOp)
7359 return failure();
7360 if (!isOrderPreserving(transposeOp))
7361 return failure();
7362
7363 VectorType resultType = transposeOp.getType();
7364
7365 // We don't need to check isValidShapeCast at this point, because it is
7366 // guaranteed that merging the transpose into the the shape_cast is a valid
7367 // shape_cast, because the transpose just inserts/removes ones.
7368
7369 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
7370 shapeCastOp.getSource());
7371 return success();
7372 }
7373};
7374
7375/// Folds transpose(from_elements(...)) into a new from_elements with permuted
7376/// operands matching the transposed shape.
7377///
7378/// Example:
7379///
7380/// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
7381/// vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to
7382/// vector<3x2xi32>
7383///
7384/// becomes ->
7385///
7386/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
7387/// vector<3x2xi32>
7388///
7389class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
7390public:
7391 using Base::Base;
7392 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
7393 PatternRewriter &rewriter) const override {
7394 auto fromElementsOp =
7395 transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
7396 if (!fromElementsOp)
7397 return failure();
7398
7399 VectorType srcTy = fromElementsOp.getDest().getType();
7400 VectorType dstTy = transposeOp.getType();
7401
7402 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
7403 int64_t rank = srcTy.getRank();
7404
7405 // Build inverse permutation to map destination indices back to source.
7406 SmallVector<int64_t> inversePerm(rank, 0);
7407 for (int64_t i = 0; i < rank; ++i)
7408 inversePerm[permutation[i]] = i;
7409
7410 ArrayRef<int64_t> srcShape = srcTy.getShape();
7411 ArrayRef<int64_t> dstShape = dstTy.getShape();
7412 SmallVector<int64_t> srcIdx(rank, 0);
7413 SmallVector<int64_t> dstIdx(rank, 0);
7414 SmallVector<int64_t> srcStrides = computeStrides(srcShape);
7415 SmallVector<int64_t> dstStrides = computeStrides(dstShape);
7416
7417 auto elementsOld = fromElementsOp.getElements();
7418 SmallVector<Value> elementsNew;
7419 int64_t dstNumElements = dstTy.getNumElements();
7420 elementsNew.reserve(dstNumElements);
7421
7422 // For each element in destination row-major order, pick the corresponding
7423 // source element.
7424 for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
7425 // Pick the destination element index.
7426 dstIdx = delinearize(linearIdx, dstStrides);
7427 // Map the destination element index to the source element index.
7428 for (int64_t j = 0; j < rank; ++j)
7429 srcIdx[j] = dstIdx[inversePerm[j]];
7430 // Linearize the source element index.
7431 int64_t srcLin = linearize(srcIdx, srcStrides);
7432 // Add the source element to the new elements.
7433 elementsNew.push_back(elementsOld[srcLin]);
7434 }
7435
7436 rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
7437 elementsNew);
7438 return success();
7439 }
7440};
7441
7442/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
7443/// 'order preserving', where 'order preserving' means the flattened
7444/// inputs and outputs of the transpose have identical (numerical) values.
7445///
7446/// Example:
7447/// ```
7448/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
7449/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
7450/// to vector<8x1xi32>
7451/// ```
7452/// can be rewritten as the equivalent
7453/// ```
7454/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
7455/// ```
7456/// The algorithm works by partitioning dimensions into groups that can be
7457/// locally permuted while preserving order, and checks that the transpose
7458/// only permutes within these groups.
7459///
7460/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
7461/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
7462/// broadcasting from 1x1x4x1x1x7.
7463/// ^^^ ^ ^^^ ^
7464/// groups: 0 1 2 3
7465/// Order preserving permutations for this example are ones that only permute
7466/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
7467class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
7468public:
7469 using Base::Base;
7470 FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
7471 : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
7472
7473 LogicalResult matchAndRewrite(vector::TransposeOp transpose,
7474 PatternRewriter &rewriter) const override {
7475
7476 vector::BroadcastOp broadcast =
7477 transpose.getVector().getDefiningOp<vector::BroadcastOp>();
7478 if (!broadcast) {
7479 return rewriter.notifyMatchFailure(transpose,
7480 "not preceded by a broadcast");
7481 }
7482
7483 auto inputType = dyn_cast<VectorType>(broadcast.getSourceType());
7484 VectorType outputType = transpose.getResultVectorType();
7485
7486 // transpose(broadcast(scalar)) -> broadcast(scalar) is always valid
7487 bool inputIsScalar = !inputType;
7488 if (inputIsScalar) {
7489 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7490 broadcast.getSource());
7491 return success();
7492 }
7493
7494 ArrayRef<int64_t> permutation = transpose.getPermutation();
7495 ArrayRef<int64_t> inputShape = inputType.getShape();
7496 int64_t inputRank = inputType.getRank();
7497 int64_t outputRank = transpose.getType().getRank();
7498 int64_t deltaRank = outputRank - inputRank;
7499
7500 int low = 0;
7501 for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
7502 bool notOne = inputShape[inputIndex] != 1;
7503 bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
7504 bool groupEndFound = notOne || prevNotOne;
7505 if (groupEndFound) {
7506 int high = inputIndex + deltaRank;
7507 // Return failure if not all permutation destinations for indices in
7508 // [low, high) are in [low, high), i.e. the permutation is not local to
7509 // the group.
7510 for (int i = low; i < high; ++i) {
7511 if (permutation[i] < low || permutation[i] >= high) {
7512 return rewriter.notifyMatchFailure(
7513 transpose, "permutation not local to group");
7514 }
7515 }
7516 low = high;
7517 }
7518 }
7519
7520 // We don't need to check the final group [low, outputRank) because if it is
7521 // not locally bound, there must be a preceding group that already failed
7522 // the check (impossible to have just 1 non-locally bound group).
7523
7524 // The preceding logic also ensures that at this point, the output of the
7525 // transpose is definitely broadcastable from the input shape, assert so:
7526 assert(vector::isBroadcastableTo(inputType, outputType) ==
7527 vector::BroadcastableToResult::Success &&
7528 "not broadcastable directly to transpose output");
7529
7530 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
7531 broadcast.getSource());
7532
7533 return success();
7534 }
7535};
7536
7537} // namespace
7538
7539void vector::TransposeOp::getCanonicalizationPatterns(
7540 RewritePatternSet &results, MLIRContext *context) {
7541 results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
7542 FoldTransposeSplat, FoldTransposeFromElements,
7543 FoldTransposeBroadcast>(context);
7544}
7545
7546//===----------------------------------------------------------------------===//
7547// ConstantMaskOp
7548//===----------------------------------------------------------------------===//
7549
7550void ConstantMaskOp::build(OpBuilder &builder, OperationState &result,
7551 VectorType type, ConstantMaskKind kind) {
7552 assert(kind == ConstantMaskKind::AllTrue ||
7553 kind == ConstantMaskKind::AllFalse);
7554 build(builder, result, type,
7555 kind == ConstantMaskKind::AllTrue
7556 ? type.getShape()
7557 : SmallVector<int64_t>(type.getRank(), 0));
7558}
7559
7560LogicalResult ConstantMaskOp::verify() {
7561 auto resultType = llvm::cast<VectorType>(getResult().getType());
7562 // Check the corner case of 0-D vectors first.
7563 if (resultType.getRank() == 0) {
7564 if (getMaskDimSizes().size() != 1)
7565 return emitError("array attr must have length 1 for 0-D vectors");
7566 auto dim = getMaskDimSizes()[0];
7567 if (dim != 0 && dim != 1)
7568 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
7569 return success();
7570 }
7571
7572 // Verify that array attr size matches the rank of the vector result.
7573 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
7574 return emitOpError(
7575 "must specify array attr of size equal vector result rank");
7576 // Verify that each array attr element is in bounds of corresponding vector
7577 // result dimension size.
7578 auto resultShape = resultType.getShape();
7579 auto resultScalableDims = resultType.getScalableDims();
7580 ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
7581 for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
7582 if (maskDimSize < 0 || maskDimSize > resultShape[index])
7583 return emitOpError(
7584 "array attr of size out of bounds of vector result dimension size");
7585 if (resultScalableDims[index] && maskDimSize != 0 &&
7586 maskDimSize != resultShape[index])
7587 return emitOpError(
7588 "only supports 'none set' or 'all set' scalable dimensions");
7589 }
7590 // Verify that if one mask dim size is zero, they all should be zero (because
7591 // the mask region is a conjunction of each mask dimension interval).
7592 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
7593 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
7594 if (anyZeros && !allZeros)
7595 return emitOpError("expected all mask dim sizes to be zeros, "
7596 "as a result of conjunction with zero mask dim");
7597 return success();
7598}
7599
7600bool ConstantMaskOp::isAllOnesMask() {
7601 auto resultType = getVectorType();
7602 // Check the corner case of 0-D vectors first.
7603 if (resultType.getRank() == 0) {
7604 assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
7605 return getMaskDimSizes()[0] == 1;
7606 }
7607 for (const auto [resultSize, maskDimSize] :
7608 llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
7609 if (maskDimSize < resultSize)
7610 return false;
7611 }
7612 return true;
7613}
7614
7615OpFoldResult ConstantMaskOp::fold(FoldAdaptor adaptor) {
7616 ArrayRef<int64_t> bounds = getMaskDimSizes();
7617 ArrayRef<int64_t> vectorSizes = getVectorType().getShape();
7618
7619 auto createBoolSplat = [&](bool x) {
7620 return SplatElementsAttr::get(getVectorType(),
7622 };
7623
7624 // Check the corner case of 0-D vectors first.
7625 if (vectorSizes.empty()) {
7626 assert(bounds.size() == 1 && "invalid sizes for zero rank mask");
7627 return createBoolSplat(bounds[0] == 1);
7628 }
7629 // Fold vector.constant_mask to splat if possible.
7630 if (bounds == vectorSizes)
7631 return createBoolSplat(true);
7632 if (llvm::all_of(bounds, [](int64_t x) { return x == 0; }))
7633 return createBoolSplat(false);
7634 return OpFoldResult();
7635}
7636
7637//===----------------------------------------------------------------------===//
7638// CreateMaskOp
7639//===----------------------------------------------------------------------===//
7640
7641void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
7642 VectorType type,
7643 ArrayRef<OpFoldResult> mixedOperands) {
7644 SmallVector<Value> operands =
7645 getValueOrCreateConstantIndexOp(builder, result.location, mixedOperands);
7646 build(builder, result, type, operands);
7647}
7648
7649LogicalResult CreateMaskOp::verify() {
7650 auto vectorType = llvm::cast<VectorType>(getResult().getType());
7651 // Verify that an operand was specified for each result vector each dimension.
7652 if (vectorType.getRank() == 0) {
7653 if (getNumOperands() != 1)
7654 return emitOpError(
7655 "must specify exactly one operand for 0-D create_mask");
7656 } else if (getNumOperands() !=
7657 llvm::cast<VectorType>(getResult().getType()).getRank()) {
7658 return emitOpError(
7659 "must specify an operand for each result vector dimension");
7660 }
7661 return success();
7662}
7663
7664namespace {
7665
7666/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
7667///
7668/// Ex 1:
7669/// %c2 = arith.constant 2 : index
7670/// %c3 = arith.constant 3 : index
7671/// %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
7672/// Becomes:
7673/// vector.constant_mask [3, 2] : vector<4x3xi1>
7674///
7675/// Ex 2:
7676/// %c_neg_1 = arith.constant -1 : index
7677/// %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
7678/// becomes:
7679/// vector.constant_mask [0] : vector<[8]xi1>
7680///
7681/// Ex 3:
7682/// %c8 = arith.constant 8 : index
7683/// %c16 = arith.constant 16 : index
7684/// %0 = vector.vscale
7685/// %1 = arith.muli %0, %c16 : index
7686/// %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
7687/// becomes:
7688/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
7689class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
7690public:
7691 using Base::Base;
7692
7693 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
7694 PatternRewriter &rewriter) const override {
7695 VectorType maskType = createMaskOp.getVectorType();
7696 ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
7697 ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
7698
7699 // Special case: Rank zero shape.
7700 constexpr std::array<int64_t, 1> rankZeroShape{1};
7701 constexpr std::array<bool, 1> rankZeroScalableDims{false};
7702 if (maskType.getRank() == 0) {
7703 maskTypeDimSizes = rankZeroShape;
7704 maskTypeDimScalableFlags = rankZeroScalableDims;
7705 }
7706
7707 // Determine if this CreateMaskOp can be folded to a ConstantMaskOp and
7708 // collect the `constantDims` (for the ConstantMaskOp).
7709 SmallVector<int64_t, 4> constantDims;
7710 for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
7711 if (auto intSize = getConstantIntValue(dimSize)) {
7712 // Constant value.
7713 // If the mask dim is non-scalable this can be any value.
7714 // If the mask dim is scalable only zero (all-false) is supported.
7715 if (maskTypeDimScalableFlags[i] && intSize >= 0)
7716 return failure();
7717 constantDims.push_back(*intSize);
7718 } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
7719 // Constant vscale multiple (e.g. 4 x vscale).
7720 // Must be all-true to fold to a ConstantMask.
7721 if (vscaleMultiplier < maskTypeDimSizes[i])
7722 return failure();
7723 constantDims.push_back(*vscaleMultiplier);
7724 } else {
7725 return failure();
7726 }
7727 }
7728
7729 // Clamp values to constant_mask bounds.
7730 for (auto [value, maskDimSize] : llvm::zip(constantDims, maskTypeDimSizes))
7731 value = std::clamp<int64_t>(value, 0, maskDimSize);
7732
7733 // If one of dim sizes is zero, set all dims to zero.
7734 if (llvm::is_contained(constantDims, 0))
7735 constantDims.assign(constantDims.size(), 0);
7736
7737 // Replace 'createMaskOp' with ConstantMaskOp.
7738 rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, maskType,
7739 constantDims);
7740 return success();
7741 }
7742};
7743
7744} // namespace
7745
7746void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
7747 MLIRContext *context) {
7748 results.add<CreateMaskFolder>(context);
7749}
7750
7751//===----------------------------------------------------------------------===//
7752// MaskOp
7753//===----------------------------------------------------------------------===//
7754
7755void MaskOp::build(
7756 OpBuilder &builder, OperationState &result, Value mask,
7757 Operation *maskableOp,
7758 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7759 assert(maskRegionBuilder &&
7760 "builder callback for 'maskRegion' must be present");
7761
7762 result.addOperands(mask);
7763 OpBuilder::InsertionGuard guard(builder);
7764 Region *maskRegion = result.addRegion();
7765 builder.createBlock(maskRegion);
7766 maskRegionBuilder(builder, maskableOp);
7767}
7768
7769void MaskOp::build(
7770 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7771 Value mask, Operation *maskableOp,
7772 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7773 build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
7774 maskRegionBuilder);
7775}
7776
7777void MaskOp::build(
7778 OpBuilder &builder, OperationState &result, TypeRange resultTypes,
7779 Value mask, Value passthru, Operation *maskableOp,
7780 function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
7781 build(builder, result, mask, maskableOp, maskRegionBuilder);
7782 if (passthru)
7783 result.addOperands(passthru);
7784 result.addTypes(resultTypes);
7785}
7786
7787ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
7788 // Create the op region.
7789 result.regions.reserve(1);
7790 Region &maskRegion = *result.addRegion();
7791
7792 auto &builder = parser.getBuilder();
7793
7794 // Parse all the operands.
7795 OpAsmParser::UnresolvedOperand mask;
7796 if (parser.parseOperand(mask))
7797 return failure();
7798
7799 // Optional passthru operand.
7800 OpAsmParser::UnresolvedOperand passthru;
7801 ParseResult parsePassthru = parser.parseOptionalComma();
7802 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
7803 return failure();
7804
7805 // Parse op region.
7806 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
7807 return failure();
7808
7809 MaskOp::ensureTerminator(maskRegion, builder, result.location);
7810
7811 // Parse the optional attribute list.
7812 if (parser.parseOptionalAttrDict(result.attributes))
7813 return failure();
7814
7815 // Parse all the types.
7816 Type maskType;
7817 if (parser.parseColonType(maskType))
7818 return failure();
7819
7820 SmallVector<Type> resultTypes;
7821 if (parser.parseOptionalArrowTypeList(resultTypes))
7822 return failure();
7823 result.types.append(resultTypes);
7824
7825 // Resolve operands.
7826 if (parser.resolveOperand(mask, maskType, result.operands))
7827 return failure();
7828
7829 if (parsePassthru.succeeded()) {
7830 if (resultTypes.empty())
7831 return parser.emitError(
7832 parser.getNameLoc(),
7833 "expects a result if passthru operand is provided");
7834
7835 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
7836 return failure();
7837 }
7838
7839 return success();
7840}
7841
7842void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
7843 p << " " << getMask();
7844 if (getPassthru())
7845 p << ", " << getPassthru();
7846
7847 // Print single masked operation and skip terminator.
7848 p << " { ";
7849 Block *singleBlock = &getMaskRegion().getBlocks().front();
7850 if (singleBlock && !singleBlock->getOperations().empty())
7851 p.printCustomOrGenericOp(&singleBlock->front());
7852 p << " }";
7853
7854 p.printOptionalAttrDict(getOperation()->getAttrs());
7855
7856 p << " : " << getMask().getType();
7857 if (getNumResults() > 0)
7858 p << " -> " << getResultTypes();
7859}
7860
7861void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
7862 // 1. For an empty `vector.mask`, create a default terminator.
7863 if (region.empty() || region.front().empty()) {
7864 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7865 MaskOp>::ensureTerminator(region, builder, loc);
7866 return;
7867 }
7868
7869 // 2. For a non-empty `vector.mask` with an explicit terminator, do nothing.
7870 Block &block = region.front();
7871 if (isa<vector::YieldOp>(block.back()))
7872 return;
7873
7874 // 3. For a non-empty `vector.mask` without an explicit terminator:
7875
7876 // Create default terminator if the number of masked operations is not
7877 // one. This case will trigger a verification failure.
7878 if (block.getOperations().size() != 1) {
7879 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
7880 MaskOp>::ensureTerminator(region, builder, loc);
7881 return;
7882 }
7883
7884 // Create a terminator that yields the results from the masked operation.
7885 OpBuilder opBuilder(builder.getContext());
7886 Operation *maskedOp = &block.front();
7887 opBuilder.setInsertionPointToEnd(&block);
7888 vector::YieldOp::create(opBuilder, loc, maskedOp->getResults());
7889}
7890
7891LogicalResult MaskOp::verify() {
7892 // Structural checks.
7893 Block &block = getMaskRegion().getBlocks().front();
7894 if (block.getOperations().empty())
7895 return emitOpError("expects a terminator within the mask region");
7896
7897 unsigned numMaskRegionOps = block.getOperations().size();
7898 if (numMaskRegionOps > 2)
7899 return emitOpError("expects only one operation to mask");
7900
7901 // Terminator checks.
7902 auto terminator = dyn_cast<vector::YieldOp>(block.back());
7903 if (!terminator)
7904 return emitOpError("expects a terminator within the mask region");
7905
7906 if (terminator->getNumOperands() != getNumResults())
7907 return emitOpError(
7908 "expects number of results to match mask region yielded values");
7909
7910 // Empty vector.mask. Nothing else to check.
7911 if (numMaskRegionOps == 1)
7912 return success();
7913
7914 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
7915 if (!maskableOp)
7916 return emitOpError("expects a MaskableOpInterface within the mask region");
7917
7918 // Result checks.
7919 if (maskableOp->getNumResults() != getNumResults())
7920 return emitOpError("expects number of results to match maskable operation "
7921 "number of results");
7922
7923 if (!llvm::equal(maskableOp->getResults(), terminator.getOperands()))
7924 return emitOpError("expects all the results from the MaskableOpInterface "
7925 "to match all the values returned by the terminator");
7926
7927 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
7928 return emitOpError(
7929 "expects result type to match maskable operation result type");
7930
7931 if (llvm::count_if(maskableOp->getResultTypes(),
7932 [](Type t) { return llvm::isa<VectorType>(t); }) > 1)
7933 return emitOpError("multiple vector results not supported");
7934
7935 // Mask checks.
7936 Type expectedMaskType = maskableOp.getExpectedMaskType();
7937 if (getMask().getType() != expectedMaskType)
7938 return emitOpError("expects a ")
7939 << expectedMaskType << " mask for the maskable operation";
7940
7941 // Passthru checks.
7942 Value passthru = getPassthru();
7943 if (passthru) {
7944 if (!maskableOp.supportsPassthru())
7945 return emitOpError(
7946 "doesn't expect a passthru argument for this maskable operation");
7947
7948 if (maskableOp->getNumResults() != 1)
7949 return emitOpError("expects result when passthru argument is provided");
7950
7951 if (passthru.getType() != maskableOp->getResultTypes()[0])
7952 return emitOpError("expects passthru type to match result type");
7953 }
7954
7955 return success();
7956}
7957
7958/// Folds empty `vector.mask` with no passthru operand and with or without
7959/// return values. For example:
7960///
7961/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
7962/// vector<8xi1> -> vector<8xf32>
7963/// %1 = user_op %0 : vector<8xf32>
7964///
7965/// becomes:
7966///
7967/// %0 = user_op %a : vector<8xf32>
7968///
7969/// Empty `vector.mask` with passthru operand are handled by the canonicalizer
7970/// as it requires creating new operations.
7971
7972static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
7973 SmallVectorImpl<OpFoldResult> &results) {
7974 if (!maskOp.isEmpty() || maskOp.hasPassthru())
7975 return failure();
7976
7977 Block *block = maskOp.getMaskBlock();
7978 auto terminator = cast<vector::YieldOp>(block->front());
7979 if (terminator.getNumOperands() == 0)
7980 return failure();
7981
7982 // `vector.mask` has results, propagate the results.
7983 llvm::append_range(results, terminator.getOperands());
7984 return success();
7985}
7986
7987LogicalResult MaskOp::fold(FoldAdaptor adaptor,
7988 SmallVectorImpl<OpFoldResult> &results) {
7989 if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
7990 return success();
7991
7992 MaskFormat maskFormat = getMaskFormat(getMask());
7993 if (maskFormat != MaskFormat::AllTrue)
7994 return failure();
7995
7996 // Move maskable operation outside of the `vector.mask` region.
7997 // If there is no maskable op (empty body), the fold cannot proceed; the
7998 // canonicalizer handles this case instead.
7999 Operation *maskableOp = getMaskableOp();
8000 if (!maskableOp)
8001 return failure();
8002 maskableOp->dropAllUses();
8003 maskableOp->moveBefore(getOperation());
8004
8005 llvm::append_range(results, maskableOp->getResults());
8006 return success();
8007}
8008
8009/// Canonialize empty `vector.mask` operations that can't be handled in
8010/// `VectorMask::fold` as they require creating new operations.
8011///
8012/// Example 1: Empty `vector.mask` with passthru operand.
8013///
8014/// %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
8015/// vector<8xi1> -> vector<8xf32>
8016///
8017/// becomes:
8018///
8019/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
8020///
8021class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
8022 using Base::Base;
8023
8024 LogicalResult matchAndRewrite(MaskOp maskOp,
8025 PatternRewriter &rewriter) const override {
8026 if (!maskOp.isEmpty())
8027 return failure();
8028
8029 if (!maskOp.hasPassthru())
8030 return failure();
8031
8032 // arith.select with a vector condition requires the value types to be
8033 // vectors of the same shape. Since vector.mask always has a vector mask
8034 // type, bail out when any result type doesn't match the mask shape to
8035 // avoid creating invalid IR.
8036 VectorType maskType = maskOp.getMask().getType();
8037 for (Type resultType : maskOp.getResultTypes()) {
8038 auto vecResultType = dyn_cast<VectorType>(resultType);
8039 if (!vecResultType || vecResultType.getShape() != maskType.getShape())
8040 return failure();
8041 }
8042
8043 Block *block = maskOp.getMaskBlock();
8044 auto terminator = cast<vector::YieldOp>(block->front());
8045 assert(terminator.getNumOperands() == 1 &&
8046 "expected one result when passthru is provided");
8047
8048 rewriter.replaceOpWithNewOp<arith::SelectOp>(
8049 maskOp, maskOp.getResultTypes(), maskOp.getMask(),
8050 terminator.getOperand(0), maskOp.getPassthru());
8051
8052 return success();
8053 }
8054};
8055
8056void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
8057 MLIRContext *context) {
8058 results.add<CanonializeEmptyMaskOp>(context);
8059}
8060
8061// MaskingOpInterface definitions.
8062
8063/// Returns the operation masked by this 'vector.mask'.
8064Operation *MaskOp::getMaskableOp() {
8065 Block *block = getMaskBlock();
8066 if (block->getOperations().size() < 2)
8067 return nullptr;
8068
8069 return &block->front();
8070}
8071
8072/// Returns true if 'vector.mask' has a passthru value.
8073bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
8074
8075//===----------------------------------------------------------------------===//
8076// ScanOp
8077//===----------------------------------------------------------------------===//
8078
8079LogicalResult ScanOp::verify() {
8080 VectorType srcType = getSourceType();
8081 VectorType initialType = getInitialValueType();
8082 // Check reduction dimension < rank.
8083 int64_t srcRank = srcType.getRank();
8084 int64_t reductionDim = getReductionDim();
8085 if (reductionDim >= srcRank)
8086 return emitOpError("reduction dimension ")
8087 << reductionDim << " has to be less than " << srcRank;
8088
8089 // Check that rank(initial_value) = rank(src) - 1.
8090 int64_t initialValueRank = initialType.getRank();
8091 if (initialValueRank != srcRank - 1)
8092 return emitOpError("initial value rank ")
8093 << initialValueRank << " has to be equal to " << srcRank - 1;
8094
8095 // Check shapes of initial value and src.
8096 ArrayRef<int64_t> srcShape = srcType.getShape();
8097 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
8098 SmallVector<int64_t> expectedShape;
8099 for (int i = 0; i < srcRank; i++) {
8100 if (i != reductionDim)
8101 expectedShape.push_back(srcShape[i]);
8102 }
8103 if (!llvm::equal(initialValueShapes, expectedShape)) {
8104 return emitOpError("incompatible input/initial value shapes");
8105 }
8106
8107 // Verify supported reduction kind.
8108 Type eltType = getDestType().getElementType();
8109 if (!isSupportedCombiningKind(getKind(), eltType))
8110 return emitOpError("unsupported reduction type ")
8111 << eltType << " for kind '" << stringifyCombiningKind(getKind())
8112 << "'";
8113
8114 return success();
8115}
8116
8118 RewritePatternSet &patterns, PatternBenefit benefit) {
  // Adds the folding patterns listed below to `patterns`, all constructed
  // with the pattern set's context and the caller-provided `benefit`.
8119 patterns
8120 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
8121 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
8122 StridedSliceConstantMaskFolder, TransposeFolder>(
8123 patterns.getContext(), benefit);
8124}
8125
8126Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
8127 CombiningKind kind, Value v1, Value acc,
8128 arith::FastMathFlagsAttr fastmath,
8129 Value mask) {
8130 Type t1 = getElementTypeOrSelf(v1.getType());
8131 Type tAcc = getElementTypeOrSelf(acc.getType());
8132 Value result;
8133
8134 switch (kind) {
8135 case CombiningKind::ADD:
8136 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
8137 result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
8138 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8139 result = b.createOrFold<arith::AddFOp>(loc, v1, acc, fastmath);
8140 else
8141 llvm_unreachable("invalid value types for ADD reduction");
8142 break;
8143 case CombiningKind::AND:
8144 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8145 result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
8146 break;
8147 case CombiningKind::MAXNUMF:
8148 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8149 "expected float values");
8150 result = b.createOrFold<arith::MaxNumFOp>(loc, v1, acc, fastmath);
8151 break;
8152 case CombiningKind::MAXIMUMF:
8153 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8154 "expected float values");
8155 result = b.createOrFold<arith::MaximumFOp>(loc, v1, acc, fastmath);
8156 break;
8157 case CombiningKind::MINNUMF:
8158 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8159 "expected float values");
8160 result = b.createOrFold<arith::MinNumFOp>(loc, v1, acc, fastmath);
8161 break;
8162 case CombiningKind::MINIMUMF:
8163 assert(llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc) &&
8164 "expected float values");
8165 result = b.createOrFold<arith::MinimumFOp>(loc, v1, acc, fastmath);
8166 break;
8167 case CombiningKind::MAXSI:
8168 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8169 result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
8170 break;
8171 case CombiningKind::MINSI:
8172 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8173 result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
8174 break;
8175 case CombiningKind::MAXUI:
8176 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8177 result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
8178 break;
8179 case CombiningKind::MINUI:
8180 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8181 result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
8182 break;
8183 case CombiningKind::MUL:
8184 if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
8185 result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
8186 else if (llvm::isa<FloatType>(t1) && llvm::isa<FloatType>(tAcc))
8187 result = b.createOrFold<arith::MulFOp>(loc, v1, acc, fastmath);
8188 else
8189 llvm_unreachable("invalid value types for MUL reduction");
8190 break;
8191 case CombiningKind::OR:
8192 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8193 result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
8194 break;
8195 case CombiningKind::XOR:
8196 assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
8197 result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
8198 break;
8199 };
8200
8201 assert(result && "unknown CombiningKind");
8202 return selectPassthru(b, mask, result, acc);
8203}
8204
8205//===----------------------------------------------------------------------===//
8206// StepOp
8207//===----------------------------------------------------------------------===//
8208
8209void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
8210 SetIntRangeFn setResultRanges) {
8211 auto resultType = cast<VectorType>(getType());
8212 if (resultType.isScalable()) {
8213 return;
8214 }
8215 unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
8216 APInt zero(bitwidth, 0);
8217 APInt high(bitwidth, resultType.getDimSize(0) - 1);
8218 ConstantIntRanges result = {zero, high, zero, high};
8219 setResultRanges(getResult(), result);
8220}
8221
8222namespace {
8223
8224/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
8225/// constant large enough such that the result is the same at all indices.
8226///
8227/// For example, rewrite the 'greater than' comparison below,
8228///
8229/// ```mlir
8230/// %cst = arith.constant dense<7> : vector<3xindex>
8231/// %stp = vector.step : vector<3xindex>
8232/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
8233/// ```
8234///
8235/// as,
8236///
8237/// ```mlir
8238/// %out = arith.constant dense<false> : vector<3xi1>.
8239/// ```
8240///
8241/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
8242/// is false at ALL indices we fold. If the constant was 1, then
8243/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do fold,
8244/// conservatively preferring the 'compact' vector.step representation.
8245///
8246/// Note: this folder only works for the case where the constant (`%cst` above)
8247/// is the second operand of the comparison. The arith.cmpi canonicalizer will
8248/// ensure that constants are always second (on the right).
8249struct StepCompareFolder : public OpRewritePattern<StepOp> {
8250 using Base::Base;
8251
8252 LogicalResult matchAndRewrite(StepOp stepOp,
8253 PatternRewriter &rewriter) const override {
8254 const int64_t stepSize = stepOp.getResult().getType().getNumElements();
8255
8256 for (OpOperand &use : stepOp.getResult().getUses()) {
8257 auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
8258 if (!cmpiOp)
8259 continue;
8260
8261 // arith.cmpi canonicalizer makes constants final operands.
8262 const unsigned stepOperandNumber = use.getOperandNumber();
8263 if (stepOperandNumber != 0)
8264 continue;
8265
8266 // Check that operand 1 is a constant.
8267 unsigned constOperandNumber = 1;
8268 Value otherOperand = cmpiOp.getOperand(constOperandNumber);
8269 std::optional<int64_t> maybeConstValue =
8270 getConstantIntValue(otherOperand);
8271 if (!maybeConstValue.has_value())
8272 continue;
8273
8274 int64_t constValue = maybeConstValue.value();
8275 arith::CmpIPredicate pred = cmpiOp.getPredicate();
8276
8277 auto maybeSplat = [&]() -> std::optional<bool> {
8278 // Handle ult (unsigned less than) and uge (unsigned greater equal).
8279 if ((pred == arith::CmpIPredicate::ult ||
8280 pred == arith::CmpIPredicate::uge) &&
8281 stepSize <= constValue)
8282 return pred == arith::CmpIPredicate::ult;
8283
8284 // Handle ule and ugt.
8285 if ((pred == arith::CmpIPredicate::ule ||
8286 pred == arith::CmpIPredicate::ugt) &&
8287 stepSize - 1 <= constValue) {
8288 return pred == arith::CmpIPredicate::ule;
8289 }
8290
8291 // Handle eq and ne.
8292 if ((pred == arith::CmpIPredicate::eq ||
8293 pred == arith::CmpIPredicate::ne) &&
8294 stepSize <= constValue)
8295 return pred == arith::CmpIPredicate::ne;
8296
8297 return std::nullopt;
8298 }();
8299
8300 if (!maybeSplat.has_value())
8301 continue;
8302
8303 rewriter.setInsertionPointAfter(cmpiOp);
8304
8305 auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
8306 if (!type)
8307 continue;
8308
8309 auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
8310 Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
8311 type, boolAttr);
8312
8313 rewriter.replaceOp(cmpiOp, splat);
8314 return success();
8315 }
8316
8317 return failure();
8318 }
8319};
8320} // namespace
8321
8322void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
8323 MLIRContext *context) {
8324 results.add<StepCompareFolder>(context);
8325}
8326
8327//===----------------------------------------------------------------------===//
8328// Vector Masking Utilities
8329//===----------------------------------------------------------------------===//
8330
8331/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
8332/// as masked operation.
8333void mlir::vector::createMaskOpRegion(OpBuilder &builder,
8334 Operation *maskableOp) {
8335 assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
8336 Block *insBlock = builder.getInsertionBlock();
8337 // Create a block and move the op to that block.
8338 insBlock->getOperations().splice(
8339 insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
8340 YieldOp::create(builder, maskableOp->getLoc(), maskableOp->getResults());
8341}
8342
8343/// Creates a vector.mask operation around a maskable operation. Returns the
8344/// vector.mask operation if the mask provided is valid. Otherwise, returns
8345/// the maskable operation itself.
8346Operation *mlir::vector::maskOperation(OpBuilder &builder,
8347 Operation *maskableOp, Value mask,
8348 Value passthru) {
8349 if (!mask)
8350 return maskableOp;
8351 if (passthru)
8352 return MaskOp::create(builder, maskableOp->getLoc(),
8353 maskableOp->getResultTypes(), mask, passthru,
8354 maskableOp, createMaskOpRegion);
8355 return MaskOp::create(builder, maskableOp->getLoc(),
8356 maskableOp->getResultTypes(), mask, maskableOp,
8358}
8359
8360/// Creates a vector select operation that picks values from `newValue` or
8361/// `passthru` for each result vector lane based on `mask`. This utility is used
8362/// to propagate the pass-thru value of vector.mask or for cases where only the
8363/// pass-thru value propagation is needed. VP intrinsics do not support
8364/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
8365/// usually able to match op + select patterns and fold them into a native
8366/// target instructions.
8367Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
8368 Value newValue, Value passthru) {
8369 if (!mask)
8370 return newValue;
8371
8372 return arith::SelectOp::create(builder, newValue.getLoc(), newValue.getType(),
8373 mask, newValue, passthru);
8374}
8375
8376//===----------------------------------------------------------------------===//
8377// InterleaveOp
8378//===----------------------------------------------------------------------===//
8379
8380namespace {
8381
8382/// This folder works on the following round-trip identity:
8383/// interleave(deinterleave(x).even, deinterleave(x).odd) -> x
8384struct InterleaveDeinterleaveFolder : public OpRewritePattern<InterleaveOp> {
8385 using Base::Base;
8386
8387 LogicalResult matchAndRewrite(InterleaveOp interleaveOp,
8388 PatternRewriter &rewriter) const override {
8389 auto lhsDefOp = interleaveOp.getLhs().getDefiningOp<DeinterleaveOp>();
8390 auto rhsDefOp = interleaveOp.getRhs().getDefiningOp<DeinterleaveOp>();
8391 if (!lhsDefOp || !rhsDefOp || lhsDefOp != rhsDefOp)
8392 return failure();
8393 for (auto [idx, operand] : llvm::enumerate(interleaveOp.getOperands())) {
8394 if (cast<OpResult>(operand).getResultNumber() != idx)
8395 return failure();
8396 }
8397 rewriter.replaceOp(interleaveOp, lhsDefOp.getSource());
8398 return success();
8399 }
8400};
8401} // namespace
8402
8403void InterleaveOp::getCanonicalizationPatterns(RewritePatternSet &results,
8404 MLIRContext *context) {
8405 results.add<InterleaveDeinterleaveFolder>(context);
8406}
8407
8408std::optional<SmallVector<int64_t, 4>> InterleaveOp::getShapeForUnroll() {
8409 return llvm::to_vector<4>(getResultVectorType().getShape());
8410}
8411
8412//===----------------------------------------------------------------------===//
8413// DeinterleaveOp
8414//===----------------------------------------------------------------------===//
8415
8416std::optional<SmallVector<int64_t, 4>> DeinterleaveOp::getShapeForUnroll() {
8417 return llvm::to_vector<4>(getResultVectorType().getShape());
8418}
8419
8420//===----------------------------------------------------------------------===//
8421// TableGen'd op method definitions
8422//===----------------------------------------------------------------------===//
8423
8424#define GET_ATTRDEF_CLASSES
8425#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
8426
8427#define GET_OP_CLASSES
8428#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Type getElementType(Type type)
Determine the element type of type.
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
if(!isCopyOut)
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static std::optional< VectorShape > vectorShape(Type type)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static MaskFormat getMaskFormat(Value mask)
Helper method to classify a mask value.
Definition VectorOps.cpp:74
static OpFoldResult foldShuffleIdentityMask(ShuffleOp op)
Fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1.
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp)
Fold the result of chains of ExtractOp in place by simply concatenating the positions.
static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp)
Folds vector.from_elements(vector.to_elements(vector)) into vector.
static bool hasZeroDimVectors(Operation *op)
Returns true if the operation has a 0-D vector type operand or result.
static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op)
static Value foldScalarExtractFromFromElements(ExtractOp extractOp)
Try to fold the extraction of a scalar from a vector defined by vector.from_elements.
static Attribute convertNumericAttr(Attribute attr, Type expectedType)
Converts numeric attributes to the expected type.
static Value foldExtractFromExtractStrided(ExtractOp extractOp)
Fold an ExtractOp from ExtractStridedSliceOp.
static llvm::SetVector< int64_t > computeBroadcastedUnitDims(ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
Return the dimensions of the result vector that were formerly ones in the source tensor and thus corr...
static Value foldExtractFromBroadcast(ExtractOp extractOp)
Fold extract(broadcast(X)) to either extract(X) or just X.
static LogicalResult foldToElementsFromElements(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.from_elements(e0, e1, ...)) into (e0, e1, ...).
static Attribute foldPoisonSrcExtractOp(Attribute srcAttr)
Fold a vector extract from is a poison source.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp)
static OpFoldResult foldShufflePoisonInputs(MLIRContext *context, Attribute v1Attr, Attribute v2Attr)
Fold shuffle poison, poison -> poison.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType)
static Attribute foldPoisonIndexInsertExtractOp(MLIRContext *context, ArrayRef< int64_t > staticPos, int64_t poisonVal)
Fold an insert or extract operation into an poison value when a poison index is found at any dimensio...
MaskFormat
Helper enum to classify mask value.
Definition VectorOps.cpp:64
static ArrayAttr makeI64ArrayAttr(ArrayRef< int64_t > values, MLIRContext *context)
static unsigned getEffectiveVectorRankForXferOp(ShapedType shapedType, VectorType vectorType)
Returns the effective rank of the vector to read/write for Xfer Ops.
static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, ArrayRef< Attribute > elements)
Fold vector.from_elements to a constant when all operands are constants.
static LogicalResult incSlicePosition(MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor, SmallVectorImpl< Value > &operands)
If the dynamic indices of extractOp or insertOp are in fact constants, then fold it.
static LogicalResult foldToElementsOfBroadcast(ToElementsOp toElementsOp, SmallVectorImpl< OpFoldResult > &results)
Folds vector.to_elements(vector.broadcast(x)) for the scalar case only.
static bool isStepIndexArray(ArrayRef< T > idxArr, uint64_t begin, size_t width)
static LogicalResult isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
static bool haveSameDefiningOp(OperandRange operands, Operation *defOp)
Returns true if all the operands are defined by defOp.
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr)
static LogicalResult isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
static bool isSplatWriteConsistentWithMaskedRead(vector::TransferWriteOp write, vector::TransferReadOp read)
Check if write is of a constant splat and the masked read is padded with the same splat value – meani...
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
static Attribute foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, Attribute dstAttr, int64_t maxVectorSizeFoldThreshold)
static LogicalResult foldTransferFullMask(TransferOp op)
static SmallVector< IntType > extractVector(ArrayAttr arrayAttr)
static std::vector< std::pair< int64_t, int64_t > > getDimMap(ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, int64_t maxIndex)
static OpFoldResult foldShuffleConstantInputs(ShuffleOp op, Attribute v1Attr, Attribute v2Attr)
Fold a shuffle of constant 1-D inputs by evaluating the mask.
static OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, Attribute foldInput)
static LogicalResult verifyPermutationMap(AffineMap permutationMap, EmitFun emitOpError)
static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter)
Rewrite vector.from_elements as vector.broadcast if the elements are the same.
static Value foldInsertUseChain(InsertOp insertOp)
Folder to replace the dest operand of the insert op with the root dest of the insert op use chain.
static bool isBroadcastLike(Operation *op)
All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are considered to be 'broadcastlike'.
static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
static Value foldExtractFromShapeCast(ExtractOp extractOp)
static LogicalResult verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
static LogicalResult verifyOutputShape(ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t > > &contractingDimMap, const std::vector< std::pair< int64_t, int64_t > > &batchDimMap)
static bool verifyDimMap(VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t > > &map)
static Type inferStridedSliceOpResultType(VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
static OpFoldResult foldShufflePoisonOperandToMask(ShuffleOp op)
If a shuffle operand is poison, replace all mask indices that reference it with kPoisonIndex.
static LogicalResult foldSize1TransferPermutationMap(TransferOp op)
When the vector type is vector<1xT>, the permutation map is irrelevant: the single vector lane always...
static Value foldExtractFromShuffle(ExtractOp extractOp)
Fold extractOp coming from ShuffleOp.
static LogicalResult foldTransferInBoundsAttribute(TransferOp op)
static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp)
Fold extract_op fed from a chain of insertStridedSlice ops.
static int64_t calculateInsertPosition(VectorType destTy, ArrayRef< int64_t > positions)
static Attribute foldDenseElementsAttrSrcExtractOp(ExtractOp extractOp, Attribute srcAttr)
Fold a vector extract extracting from a DenseElementsAttr.
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
#define mul(a, b)
Rewrite from_elements on multiple scalar extracts as a shape_cast on a single extract.
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
MLIRContext * getContext() const
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
unsigned getNumInputs() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
SmallVector< unsigned > getBroadcastDims() const
Returns the list of broadcast dimensions (i.e.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Base storage class appearing in an attribute.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition Attributes.h:58
bool empty()
Definition Block.h:158
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
iterator begin()
Definition Block.h:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:171
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
IntegerType getI1Type()
Definition Builders.cpp:57
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:286
IndexType getIndexType()
Definition Builders.cpp:55
ArrayAttr getBoolArrayAttr(ArrayRef< bool > values)
Definition Builders.cpp:275
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:323
static unsigned getStorageBitwidth(Type type)
Return the bitwidth that should be used for integer ranges describing type.
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
Definition Dialect.h:83
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printCustomOrGenericOp(Operation *op)=0
Prints the entire operation with the custom assembly form, if available, or the generic assembly form...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:567
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Value getOperand(unsigned idx)
Definition Operation.h:375
void dropAllUses()
Drop all uses of results of this operation.
Definition Operation.h:859
void setOperand(unsigned idx, Value value)
Definition Operation.h:376
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
operand_type_range getOperandTypes()
Definition Operation.h:422
result_type_range getResultTypes()
Definition Operation.h:453
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
result_range getResults()
Definition Operation.h:440
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
T * allocate()
Allocate an instance of the provided type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
bool isF32() const
Definition Types.cpp:40
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
Definition Types.cpp:114
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
static FailureOr< int64_t > computeConstantDelta(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute a constant delta between the given two values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
This is a builder type that keeps local references to arguments.
Builder & setElementType(Type newElementType)
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:114
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
FailureOr< int64_t > fullyComposeAndComputeConstantDelta(Value value1, Value value2)
Compute a constant delta of the given two values.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
detail::poison_attr_matcher m_Poison()
Matches a poison constant (any attribute implementing PoisonAttrInterface).
Definition UBMatchers.h:46
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
std::optional< int64_t > getConstantVscaleMultiplier(Value value)
If value is a constant multiple of vector.vscale (e.g.
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
ConstantMaskKind
Predefined constant_mask kinds.
Definition VectorOps.h:64
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< VectorDim, VectorDim > *mismatchingDims=nullptr)
VectorType inferTransferOpMaskType(VectorType vecType, AffineMap permMap)
Infers the mask type for a transfer op given its vector type and permutation map.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, without requring the...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, bool testDynamicValueUsingBounds=false)
Return true if we can prove that the transfer operations access disjoint memory, requiring the operat...
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully over-write the priorWrite transfer_write op.
SmallVector< int64_t > getAsIntegers(ArrayRef< Value > values)
Returns the integer numbers in values.
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
SmallVector< Value > getAsValues(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > foldResults)
Convert foldResults into Values.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition VectorOps.h:72
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::function_ref< void(Value, const ConstantIntRanges &)> SetIntRangeFn
The type of the setResultRanges callback provided to ops implementing InferIntRangeInterface.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
StorageUniquer::StorageAllocator AttributeStorageAllocator
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< int64_t > getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront=0, unsigned dropBack=0)
Helper to return a subset of arrayAttr as a vector of int64_t.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineMap compressUnusedDims(AffineMap map)
Drop the dims that are not used.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef< AffineMap > maps)
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Return a fused vector::ContractionOp which represents a patterns such as:
LogicalResult matchAndRewrite(AddOpType addOp, PatternRewriter &rewriter) const override
Canonicalize vector.to_elements(vector.broadcast(v)) where v is a vector.
LogicalResult matchAndRewrite(ToElementsOp toElementsOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static BitmaskEnumStorage * construct(AttributeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const