18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/InterleavedRange.h"
24#define DEBUG_TYPE "vector-unroll"
34 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
35 return constExpr.getValue() == 0;
40 for (
const auto &dim : llvm::enumerate(permutationMap.
getResults())) {
41 int64_t elementOffset = elementOffsets[dim.index()];
42 if (isBroadcast(dim.value()) || elementOffset == 0)
44 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
49 affine::AffineApplyOp::create(builder, loc, map,
indices[pos]);
61 assert(offsets.size() <= originalIndices.size() &&
62 "Offsets should not exceed the number of original indices");
65 auto start =
indices.size() - offsets.size();
66 for (
auto [i, offset] : llvm::enumerate(offsets)) {
68 indices[start + i] = arith::AddIOp::create(
69 rewriter, loc, originalIndices[start + i],
// Decides whether `op` should be unrolled under `options`, logging every
// decision through LDBG(). The function name, its parameter list and the
// return statements sit on lines that are not part of this excerpt.
88static std::optional<SmallVector<int64_t>>
// A user-supplied filter constraint, when set, is consulted first; a failed
// constraint rejects the op.
91 if (
options.filterConstraint && failed(
options.filterConstraint(op))) {
92 LDBG() <<
"--no filter constraint -> BAIL";
// Assertion message split over two literals; the trailing space on the first
// literal keeps the words "native" and "shape" separated after concatenation.
96 "vector unrolling expects the native shape or native "
97 "shape call back function to be set");
// Only ops that implement VectorUnrollOpInterface are considered.
98 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
99 if (!unrollableVectorOp) {
100 LDBG() <<
"--not an unrollable op -> BAIL";
// Ask the op for the shape that unrolling works on; an empty optional means
// the op cannot provide one.
103 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
104 if (!maybeUnrollShape) {
105 LDBG() <<
"--could not get shape of op " << *op <<
" -> BAIL";
108 LDBG() <<
"--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
// Query the requested target shape for this op from the options.
110 std::optional<SmallVector<int64_t>> targetShape =
options.nativeShape(op);
112 LDBG() <<
"--no unrolling target shape defined " << *op <<
" -> SKIP";
115 LDBG() <<
"--target shape: " << llvm::interleaved(*targetShape);
// `maybeShapeRatio` is computed on a line not shown in this excerpt; an empty
// value is reported as a non-integral ratio.
118 if (!maybeShapeRatio) {
119 LDBG() <<
"--could not compute integral shape ratio -> BAIL";
// A ratio of 1 in every dimension means there is nothing to unroll.
122 if (llvm::all_of(*maybeShapeRatio, [](
int64_t v) {
return v == 1; })) {
123 LDBG() <<
"--no unrolling needed -> SKIP";
126 LDBG() <<
"--found an integral shape ratio to unroll to -> SUCCESS";
134 llvm::to_vector(llvm::seq<int64_t>(0,
static_cast<int64_t>(numLoops)));
135 if (
options.traversalOrderCallback !=
nullptr) {
136 std::optional<SmallVector<int64_t>> order =
137 options.traversalOrderCallback(op);
139 loopOrder = std::move(*order);
147struct UnrollTransferReadPattern
149 UnrollTransferReadPattern(MLIRContext *context,
150 const vector::UnrollVectorOptions &options,
151 PatternBenefit benefit = 1)
152 : OpRewritePattern<vector::TransferReadOp>(context, benefit),
155 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
156 PatternRewriter &rewriter)
const override {
158 if (readOp.getTransferRank() == 0)
160 if (readOp.getMask())
165 auto sourceVectorType = readOp.getVectorType();
166 SmallVector<int64_t> strides(targetShape->size(), 1);
167 Location loc = readOp.getLoc();
168 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
172 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
175 VectorType::get(*targetShape, sourceVectorType.getElementType());
176 SmallVector<Value> originalIndices(readOp.getIndices().begin(),
177 readOp.getIndices().end());
178 SmallVector<int64_t> loopOrder =
180 for (SmallVector<int64_t> elementOffsets :
181 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
183 sliceTransferIndices(elementOffsets, originalIndices,
184 readOp.getPermutationMap(), loc, rewriter);
185 auto slicedRead = vector::TransferReadOp::create(
186 rewriter, loc, targetType, readOp.getBase(),
indices,
187 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
188 readOp.getInBoundsAttr());
191 loc, slicedRead,
result, elementOffsets, strides);
198 vector::UnrollVectorOptions options;
201struct UnrollTransferWritePattern
203 UnrollTransferWritePattern(MLIRContext *context,
204 const vector::UnrollVectorOptions &options,
205 PatternBenefit benefit = 1)
206 : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
209 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
210 PatternRewriter &rewriter)
const override {
212 if (writeOp.getTransferRank() == 0)
215 if (writeOp.getMask())
220 auto sourceVectorType = writeOp.getVectorType();
221 SmallVector<int64_t> strides(targetShape->size(), 1);
222 Location loc = writeOp.getLoc();
223 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
227 if (originalSize.size() != targetShape->size())
230 "expected source input vector rank to match target shape rank");
232 SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
233 writeOp.getIndices().end());
234 SmallVector<int64_t> loopOrder =
237 for (SmallVector<int64_t> elementOffsets :
238 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
239 Value slicedVector = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
240 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
242 sliceTransferIndices(elementOffsets, originalIndices,
243 writeOp.getPermutationMap(), loc, rewriter);
244 Operation *slicedWrite = vector::TransferWriteOp::create(
245 rewriter, loc, slicedVector,
246 resultTensor ? resultTensor : writeOp.getBase(),
indices,
247 writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
250 resultTensor = slicedWrite->
getResult(0);
253 rewriter.
replaceOp(writeOp, resultTensor);
260 vector::UnrollVectorOptions options;
// Traits helper for maps keyed by offset vectors. In this file it is passed as
// the third template argument of
// llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo> inside the
// contraction and multi-reduction patterns.
263struct OffsetMapInfo {
// Hash of a key: all elements of the offset vector are combined with
// llvm::hash_combine_range and the result is narrowed to unsigned.
264 static unsigned getHashValue(
const SmallVector<int64_t> &v) {
265 return static_cast<unsigned>(llvm::hash_combine_range(v));
// Equality predicate over two offset vectors (its body is not part of this
// excerpt).
268 static bool isEqual(
const SmallVector<int64_t> &
lhs,
269 const SmallVector<int64_t> &
rhs) {
274struct UnrollContractionPattern
276 UnrollContractionPattern(MLIRContext *context,
277 const vector::UnrollVectorOptions &options,
278 PatternBenefit benefit = 1)
279 : OpRewritePattern<vector::ContractionOp>(context, benefit),
282 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
283 PatternRewriter &rewriter)
const override {
287 auto dstVecType = cast<VectorType>(contractOp.getResultType());
288 SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
290 Location loc = contractOp.getLoc();
291 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
292 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
294 SmallVector<int64_t>, Value,
295 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
299 contractOp.getIteratorTypes().size(), contractOp, options);
301 for (SmallVector<int64_t> offsets :
302 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
303 SmallVector<Value> slicesOperands(contractOp.getNumOperands());
306 auto extractOperand = [&](
unsigned index, Value operand,
307 AffineMap permutationMap,
308 ArrayRef<int64_t> operandOffets) {
310 permutationMap, ArrayRef<int64_t>(*targetShape));
311 SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
312 slicesOperands[index] =
314 loc, operand, operandOffets, operandShape, operandStrides);
318 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
319 SmallVector<int64_t> lhsOffets =
321 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
324 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
325 SmallVector<int64_t> rhsOffets =
327 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
329 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
330 SmallVector<int64_t> accOffets =
334 auto *accIt = accCache.find(accOffets);
335 if (accIt != accCache.end())
336 slicesOperands[2] = accIt->second;
338 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
340 SmallVector<int64_t> dstShape =
342 auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
344 rewriter, loc, contractOp, slicesOperands, targetType);
346 SmallVector<int64_t> dstOffets =
350 accCache[dstOffets] = newOp->
getResult(0);
353 Value
result = arith::ConstantOp::create(rewriter, loc, dstVecType,
355 for (
const auto &it : accCache) {
356 SmallVector<int64_t> dstStrides(it.first.size(), 1);
358 loc, it.second,
result, it.first, dstStrides);
365 vector::UnrollVectorOptions options;
368struct UnrollMultiReductionPattern
370 UnrollMultiReductionPattern(MLIRContext *context,
371 const vector::UnrollVectorOptions &options,
372 PatternBenefit benefit = 1)
373 : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
376 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
377 PatternRewriter &rewriter)
const override {
378 std::optional<SmallVector<int64_t>> targetShape =
382 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
383 Location loc = reductionOp.getLoc();
384 auto resultType = reductionOp->getResult(0).getType();
389 if (resultType.isIntOrFloat()) {
390 Value accumulator = reductionOp.getAcc();
391 for (SmallVector<int64_t> offsets :
392 StaticTileOffsetRange(originalSize, *targetShape)) {
393 SmallVector<int64_t> operandStrides(offsets.size(), 1);
394 Value slicedOperand =
396 loc, reductionOp.getSource(), offsets, *targetShape,
399 rewriter, loc, reductionOp, {slicedOperand, accumulator},
403 rewriter.
replaceOp(reductionOp, accumulator);
409 SmallVector<int64_t>, Value,
410 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
415 for (SmallVector<int64_t> offsets :
416 StaticTileOffsetRange(originalSize, *targetShape)) {
417 SmallVector<Value> operands;
418 SmallVector<int64_t> operandStrides(offsets.size(), 1);
419 Value slicedOperand =
421 loc, reductionOp.getSource(), offsets, *targetShape,
423 operands.push_back(slicedOperand);
424 SmallVector<int64_t> dstShape;
425 SmallVector<int64_t> destOffset;
426 for (
size_t i : llvm::seq(
size_t(0), targetShape->size())) {
427 if (!reductionOp.isReducedDim(i)) {
428 destOffset.push_back(offsets[i]);
429 dstShape.push_back((*targetShape)[i]);
433 SmallVector<int64_t> accStrides(destOffset.size(), 1);
436 auto *accIt = accCache.find(destOffset);
437 if (accIt != accCache.end())
440 acc = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
441 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
442 operands.push_back(acc);
443 auto targetType = VectorType::get(
444 dstShape, reductionOp.getSourceVectorType().getElementType());
446 operands, targetType);
448 accCache[destOffset] =
result;
451 Value
result = arith::ConstantOp::create(
452 rewriter, loc, reductionOp.getDestType(),
454 for (
const auto &it : accCache) {
455 SmallVector<int64_t> dstStrides(it.first.size(), 1);
457 loc, it.second,
result, it.first, dstStrides);
464 vector::UnrollVectorOptions options;
468 UnrollElementwisePattern(MLIRContext *context,
469 const vector::UnrollVectorOptions &options,
470 PatternBenefit benefit = 1)
471 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
474 LogicalResult matchAndRewrite(Operation *op,
475 PatternRewriter &rewriter)
const override {
481 int64_t targetShapeRank = targetShape->size();
483 SmallVector<int64_t> originalSize =
484 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
485 int64_t originalShapeRank = originalSize.size();
487 Location loc = op->
getLoc();
490 SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
491 int64_t rankDiff = originalShapeRank - targetShapeRank;
492 std::fill(adjustedTargetShape.begin(),
493 adjustedTargetShape.begin() + rankDiff, 1);
494 std::copy(targetShape->begin(), targetShape->end(),
495 adjustedTargetShape.begin() + rankDiff);
497 int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
499 Value
result = arith::ConstantOp::create(rewriter, loc, dstVecType,
501 SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
502 VectorType unrolledVecType =
503 VectorType::get(*targetShape, dstVecType.getElementType());
506 for (SmallVector<int64_t> offsets :
507 StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
508 SmallVector<Value> extractOperands;
510 auto vecType = dyn_cast<VectorType>(operand.get().getType());
512 extractOperands.push_back(operand.get());
515 Value extracted = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
516 loc, operand.get(), offsets, adjustedTargetShape, strides);
519 if (adjustedTargetShapeRank > targetShapeRank) {
521 loc, VectorType::get(*targetShape, vecType.getElementType()),
524 extractOperands.push_back(extracted);
528 rewriter, loc, op, extractOperands, unrolledVecType);
530 Value computeResult = newOp->
getResult(0);
533 SmallVector<int64_t> insertStrides =
534 (adjustedTargetShapeRank > targetShapeRank)
535 ? SmallVector<int64_t>(targetShapeRank, 1)
539 loc, computeResult,
result, offsets, insertStrides);
546 vector::UnrollVectorOptions options;
549struct UnrollReductionPattern :
public OpRewritePattern<vector::ReductionOp> {
550 UnrollReductionPattern(MLIRContext *context,
551 const vector::UnrollVectorOptions &options,
552 PatternBenefit benefit = 1)
553 : OpRewritePattern<vector::ReductionOp>(context, benefit),
556 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
557 PatternRewriter &rewriter)
const override {
558 std::optional<SmallVector<int64_t>> targetShape =
562 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
565 Location loc = reductionOp.getLoc();
566 Value accumulator =
nullptr;
567 for (SmallVector<int64_t> offsets :
568 StaticTileOffsetRange(originalSize, *targetShape)) {
569 SmallVector<int64_t> strides(offsets.size(), 1);
570 Value slicedOperand =
572 loc, reductionOp.getVector(), offsets, *targetShape, strides);
574 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
587 rewriter.
replaceOp(reductionOp, accumulator);
592 const vector::UnrollVectorOptions options;
595struct UnrollTransposePattern :
public OpRewritePattern<vector::TransposeOp> {
596 UnrollTransposePattern(MLIRContext *context,
597 const vector::UnrollVectorOptions &options,
598 PatternBenefit benefit = 1)
599 : OpRewritePattern<vector::TransposeOp>(context, benefit),
602 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
603 PatternRewriter &rewriter)
const override {
604 if (transposeOp.getResultVectorType().getRank() == 0)
609 auto originalVectorType = transposeOp.getResultVectorType();
610 SmallVector<int64_t> strides(targetShape->size(), 1);
611 Location loc = transposeOp.getLoc();
612 ArrayRef<int64_t> originalSize = originalVectorType.getShape();
616 arith::ConstantOp::create(rewriter, loc, originalVectorType,
618 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
621 for (SmallVector<int64_t> elementOffsets :
622 StaticTileOffsetRange(originalSize, *targetShape)) {
623 SmallVector<int64_t> permutedOffsets(elementOffsets.size());
624 SmallVector<int64_t> permutedShape(elementOffsets.size());
626 for (
auto indices : llvm::enumerate(permutation)) {
627 permutedOffsets[
indices.value()] = elementOffsets[
indices.index()];
630 Value slicedOperand =
632 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
634 Value transposedSlice = rewriter.
createOrFold<vector::TransposeOp>(
635 loc, slicedOperand, permutation);
637 loc, transposedSlice,
result, elementOffsets, strides);
644 vector::UnrollVectorOptions options;
648 UnrollGatherPattern(MLIRContext *context,
649 const vector::UnrollVectorOptions &options,
650 PatternBenefit benefit = 1)
651 : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
654 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
655 PatternRewriter &rewriter)
const override {
656 VectorType sourceVectorType = gatherOp.getVectorType();
657 if (sourceVectorType.getRank() == 0)
662 SmallVector<int64_t> strides(targetShape->size(), 1);
663 Location loc = gatherOp.getLoc();
664 ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
668 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
671 VectorType::get(*targetShape, sourceVectorType.getElementType());
673 SmallVector<int64_t> loopOrder =
675 for (SmallVector<int64_t> elementOffsets :
676 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
680 Value indexSubVec = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
681 loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
682 Value maskSubVec = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
683 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
684 Value passThruSubVec =
686 loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
688 auto slicedGather = vector::GatherOp::create(
689 rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
690 indexSubVec, maskSubVec, passThruSubVec);
693 loc, slicedGather,
result, elementOffsets, strides);
700 vector::UnrollVectorOptions options;
704 UnrollLoadPattern(MLIRContext *context,
705 const vector::UnrollVectorOptions &options,
706 PatternBenefit benefit = 1)
707 : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
709 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
710 PatternRewriter &rewriter)
const override {
711 VectorType vecType = loadOp.getVectorType();
717 Location loc = loadOp.getLoc();
718 ArrayRef<int64_t> originalShape = vecType.getShape();
719 SmallVector<int64_t> strides(targetShape->size(), 1);
721 Value
result = arith::ConstantOp::create(rewriter, loc, vecType,
724 SmallVector<int64_t> loopOrder =
728 VectorType::get(*targetShape, vecType.getElementType());
730 for (SmallVector<int64_t> offsets :
731 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
734 Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
737 loc, slicedLoad,
result, offsets, strides);
744 vector::UnrollVectorOptions options;
748 UnrollStorePattern(MLIRContext *context,
749 const vector::UnrollVectorOptions &options,
750 PatternBenefit benefit = 1)
751 : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
753 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
754 PatternRewriter &rewriter)
const override {
755 VectorType vecType = storeOp.getVectorType();
761 Location loc = storeOp.getLoc();
762 ArrayRef<int64_t> originalShape = vecType.getShape();
763 SmallVector<int64_t> strides(targetShape->size(), 1);
765 Value base = storeOp.getBase();
766 Value vector = storeOp.getValueToStore();
768 SmallVector<int64_t> loopOrder =
771 for (SmallVector<int64_t> offsets :
772 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
775 Value slice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
776 loc, vector, offsets, *targetShape, strides);
777 vector::StoreOp::create(rewriter, loc, slice, base,
indices);
784 vector::UnrollVectorOptions options;
787struct UnrollBroadcastPattern :
public OpRewritePattern<vector::BroadcastOp> {
788 UnrollBroadcastPattern(MLIRContext *context,
789 const vector::UnrollVectorOptions &options,
790 PatternBenefit benefit = 1)
791 : OpRewritePattern<vector::BroadcastOp>(context, benefit),
794 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
795 PatternRewriter &rewriter)
const override {
800 Location loc = broadcastOp.getLoc();
801 VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
802 VectorType resType = broadcastOp.getResultVectorType();
803 VectorType targetType =
804 resType.cloneWith(*targetShape, resType.getElementType());
805 Value
result = arith::ConstantOp::create(rewriter, loc, resType,
808 SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
809 SmallVector<int64_t> strides(originalShape.size(), 1);
811 for (SmallVector<int64_t> offsets :
812 StaticTileOffsetRange(originalShape, *targetShape)) {
816 newSrc = broadcastOp.getSource();
819 int64_t rank = srcType.getRank();
820 SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
821 SmallVector<int64_t> srcShape(targetShape->end() - rank,
823 SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
825 for (int64_t i = 0; i < rank; ++i) {
826 if (srcType.getDimSize(i) == 1) {
831 newSrc = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
832 loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
847 vector::UnrollVectorOptions options;
867struct UnrollToElements final :
public OpRewritePattern<vector::ToElementsOp> {
868 UnrollToElements(MLIRContext *context,
869 const vector::UnrollVectorOptions &options,
870 PatternBenefit benefit = 1)
871 : OpRewritePattern<vector::ToElementsOp>(context, benefit),
874 LogicalResult matchAndRewrite(vector::ToElementsOp op,
875 PatternRewriter &rewriter)
const override {
878 FailureOr<SmallVector<Value>>
result =
883 SmallVector<Value> vectors = *
result;
885 SmallVector<Value> results;
886 for (Value vector : vectors) {
888 vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
889 llvm::append_range(results, subElements.getResults());
896 vector::UnrollVectorOptions options;
926 UnrollStepPattern(MLIRContext *context,
927 const vector::UnrollVectorOptions &options,
928 PatternBenefit benefit = 1)
929 : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
931 LogicalResult matchAndRewrite(vector::StepOp stepOp,
932 PatternRewriter &rewriter)
const override {
933 std::optional<SmallVector<int64_t>> targetShape =
938 VectorType vecType = stepOp.getType();
939 if (vecType.isScalable()) {
943 int64_t originalSize = vecType.getShape()[0];
944 Location loc = stepOp.getLoc();
945 SmallVector<int64_t> strides(1, 1);
947 Value
result = arith::ConstantOp::create(rewriter, loc, vecType,
951 VectorType::get(*targetShape, vecType.getElementType());
952 Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
953 for (
const SmallVector<int64_t> &offsets :
954 StaticTileOffsetRange({originalSize}, *targetShape)) {
955 Value bcastOffset = arith::ConstantOp::create(
956 rewriter, loc, targetVecType,
959 IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
961 arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
964 loc, tileStep,
result, offsets, strides);
971 vector::UnrollVectorOptions options;
992 UnrollFromElements(MLIRContext *context,
993 const vector::UnrollVectorOptions &options,
994 PatternBenefit benefit = 1)
995 : OpRewritePattern<vector::FromElementsOp>(context, benefit),
998 LogicalResult matchAndRewrite(vector::FromElementsOp op,
999 PatternRewriter &rewriter)
const override {
1002 auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
1003 VectorType subTy, int64_t index) {
1004 size_t subTyNumElements = subTy.getNumElements();
1005 assert((index + 1) * subTyNumElements <= allElements.size() &&
1008 allElements.slice(index * subTyNumElements, subTyNumElements);
1009 return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
1016 vector::UnrollVectorOptions options;
1052struct UnrollCreateMaskPattern :
public OpRewritePattern<vector::CreateMaskOp> {
1053 UnrollCreateMaskPattern(MLIRContext *context,
1054 const vector::UnrollVectorOptions &options,
1055 PatternBenefit benefit = 1)
1056 : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1059 LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
1060 PatternRewriter &rewriter)
const override {
1065 VectorType resultType = createMaskOp.getVectorType();
1066 SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
1067 Location loc = createMaskOp.getLoc();
1069 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1071 VectorType targetVectorType =
1072 VectorType::get(*targetShape, rewriter.
getI1Type());
1073 SmallVector<int64_t> strides(targetShape->size(), 1);
1077 for (SmallVector<int64_t> offsets :
1078 StaticTileOffsetRange(originalSize, *targetShape)) {
1079 SmallVector<Value> unrolledOperands;
1081 for (
auto [i, originalMaskOperand] :
1082 llvm::enumerate(createMaskOp.getOperands())) {
1085 Value adjustedMaskSize = rewriter.
createOrFold<arith::SubIOp>(
1086 loc, originalMaskOperand, offsetVal);
1088 Value unrolledDimSize =
1091 rewriter.
createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1092 Value unrolledOperand = rewriter.
createOrFold<arith::MinSIOp>(
1093 loc, nonNegative, unrolledDimSize);
1094 unrolledOperands.push_back(unrolledOperand);
1097 auto unrolledMask = rewriter.
createOrFold<vector::CreateMaskOp>(
1098 loc, targetVectorType, unrolledOperands);
1100 loc, unrolledMask,
result, offsets, strides);
1107 vector::UnrollVectorOptions options;
// Rewrite pattern that splits a vector::ConstantMaskOp into one smaller
// ConstantMaskOp per tile of the target shape. Several lines of the original
// definition (base class, early exits, final replacement) are not part of this
// excerpt.
1142struct UnrollConstantMaskPattern
// Stores the unroll options and forwards context/benefit to the base pattern.
1144 UnrollConstantMaskPattern(MLIRContext *context,
1145 const vector::UnrollVectorOptions &options,
1146 PatternBenefit benefit = 1)
1147 : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1150 LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
1151 PatternRewriter &rewriter)
const override {
// The initializer of targetShape is on a line not shown here.
1152 std::optional<SmallVector<int64_t>> targetShape =
1157 VectorType resultType = constantMaskOp.getVectorType();
1158 SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
1159 Location loc = constantMaskOp.getLoc();
// `result` starts as a constant of the full result type (the constant's value
// is on a line not shown here).
1161 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
// Every tile is an i1 vector with the target shape.
1163 VectorType targetVectorType =
1164 VectorType::get(*targetShape, rewriter.
getI1Type());
1165 SmallVector<int64_t> strides(targetShape->size(), 1);
// Visit each tile of the original shape; `offsets` is the tile's origin.
1169 for (
const SmallVector<int64_t> &offsets :
1170 StaticTileOffsetRange(originalSize, *targetShape)) {
1171 SmallVector<int64_t> unrolledMaskDims;
// Per dimension, the tile's mask size is the original mask size shifted by the
// tile offset and clamped to the range [0, targetShape[i]].
1173 for (
auto [i, originalMaskDim] :
1174 llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
// Lower clamp: a tile that lies entirely past the mask boundary gets 0.
1177 int64_t adjustedMaskSize =
1178 std::max(originalMaskDim - offsets[i],
static_cast<int64_t
>(0));
// Upper clamp: a tile cannot have more set positions than its own extent.
1179 int64_t unrolledMaskDim =
1180 std::min(adjustedMaskSize,
static_cast<int64_t
>((*targetShape)[i]));
1181 unrolledMaskDims.push_back(unrolledMaskDim);
// Build (or fold) the mask for this tile from the clamped sizes.
1184 auto unrolledMask = rewriter.
createOrFold<vector::ConstantMaskOp>(
1185 loc, targetVectorType, unrolledMaskDims);
// Arguments of a call whose callee is on a line not shown here; it receives
// the tile mask, the running `result`, the tile offsets and unit strides.
1187 loc, unrolledMask,
result, offsets, strides);
// Unroll configuration captured at construction time.
1194 vector::UnrollVectorOptions options;
// Body of a check relating `extractShape` to `shape`. Presumably this is the
// isContiguous helper called by the shape-cast pattern; the signature is not
// part of this excerpt — TODO confirm.
// Guard: an empty shape on either side, or an extract shape with more
// dimensions than `shape`, takes the branch below (its statement is not shown).
1212 if (extractShape.empty() ||
shape.empty() ||
1213 extractShape.size() >
shape.size())
// Strip leading unit dimensions from extractShape, always keeping at least
// one dimension.
1216 while (extractShape.size() > 1 && extractShape.front() == 1)
1217 extractShape = extractShape.drop_front();
// Loop over leading unit dimensions of `shape` with the same condition; the
// loop body is not shown here.
1219 while (
shape.size() > 1 &&
shape.front() == 1) {
// Every dimension of extractShape after its first must equal the same number
// of innermost dimensions of `shape`.
1223 size_t rankDiff =
shape.size() - extractShape.size();
1224 if (!llvm::equal(extractShape.drop_front(),
shape.drop_front(rankDiff + 1)))
// Final result: the element count of `shape` must be a multiple of the
// element count of extractShape.
1227 int64_t extractElements = ShapedType::getNumElements(extractShape);
1228 int64_t shapeElements = ShapedType::getNumElements(
shape);
1229 return shapeElements % extractElements == 0;
// Builds a shape of the same rank as `sourceShape` that holds exactly
// `targetElements` elements, filling it from the last source dimension
// backwards. Returns std::nullopt when no such shape can be formed. Presumably
// this is calculateSourceExtractShape, which the shape-cast pattern calls; the
// name and parameter list are on lines not part of this excerpt — TODO confirm.
1251static std::optional<SmallVector<int64_t>>
1255 int64_t remainingElements = targetElements;
// Walk the source dimensions from last to first until the requested element
// count has been fully distributed.
1258 for (
int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
// Take as much as possible from this dimension, capped by its size, and
// prepend it so extractShape keeps the source's dimension order.
1259 int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1260 extractShape.insert(extractShape.begin(), takeFromDim);
// The remaining count must split evenly by what was taken from this dimension.
1262 if (remainingElements % takeFromDim != 0)
1263 return std::nullopt;
1264 remainingElements /= takeFromDim;
// Pad with leading 1s until the rank matches the source rank.
1268 while (extractShape.size() < sourceShape.size())
1269 extractShape.insert(extractShape.begin(), 1);
// Reject the result if its element count is not exactly the requested one
// (e.g. the loop ran out of dimensions first).
1271 if (ShapedType::getNumElements(extractShape) != targetElements)
1272 return std::nullopt;
1274 return extractShape;
1319struct UnrollShapeCastPattern :
public OpRewritePattern<vector::ShapeCastOp> {
1320 UnrollShapeCastPattern(MLIRContext *context,
1321 const vector::UnrollVectorOptions &options,
1322 PatternBenefit benefit = 1)
1323 : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1326 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1327 PatternRewriter &rewriter)
const override {
1328 std::optional<SmallVector<int64_t>> targetShape =
1333 VectorType sourceType = shapeCastOp.getSourceVectorType();
1334 VectorType resultType = shapeCastOp.getResultVectorType();
1335 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1336 ArrayRef<int64_t> resultShape = resultType.getShape();
1338 if (!isContiguous(*targetShape, resultShape))
1340 shapeCastOp,
"Only supports cases where target shape is "
1341 "contiguous in result vector shape");
1343 int64_t targetElements = ShapedType::getNumElements(*targetShape);
1346 std::optional<SmallVector<int64_t>> extractShape =
1347 calculateSourceExtractShape(sourceShape, targetElements);
1351 "cannot extract target number of elements contiguously from source");
1353 Location loc = shapeCastOp.getLoc();
1356 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1359 VectorType targetType =
1360 VectorType::get(*targetShape, sourceType.getElementType());
1363 SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1365 for (SmallVector<int64_t> resultOffsets :
1366 StaticTileOffsetRange(resultShape, *targetShape)) {
1367 SmallVector<int64_t> sourceOffsets =
1368 calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
1369 Value sourceChunk = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1370 loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1372 Value targetChunk = rewriter.
createOrFold<vector::ShapeCastOp>(
1373 loc, targetType, sourceChunk);
1375 loc, targetChunk,
result, resultOffsets, insertStrides);
1383 vector::UnrollVectorOptions options;
1401 UnrollBitCastPattern(MLIRContext *context,
1402 const vector::UnrollVectorOptions &options,
1403 PatternBenefit benefit = 1)
1404 : OpRewritePattern<vector::BitCastOp>(context, benefit),
1407 LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1408 PatternRewriter &rewriter)
const override {
1412 "failed to get target shape");
1414 VectorType sourceType = bitCastOp.getSourceVectorType();
1415 VectorType resultType = bitCastOp.getResultVectorType();
1416 ArrayRef<int64_t> resultShape = resultType.getShape();
1417 Location loc = bitCastOp.getLoc();
1419 if (targetShape->size() != resultShape.size())
1421 bitCastOp,
"target shape rank must match result rank");
1423 unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
1424 unsigned resultElementBits = resultType.getElementTypeBitWidth();
1426 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1427 targetShape->end());
1428 int64_t lastDim = sourceSliceShape.size() - 1;
1430 sourceSliceShape[lastDim] =
1431 ((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
1433 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1435 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1436 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
1438 VectorType targetType =
1439 VectorType::get(*targetShape, resultType.getElementType());
1441 for (SmallVector<int64_t> resultOffsets :
1442 StaticTileOffsetRange(resultShape, *targetShape)) {
1443 SmallVector<int64_t> sourceOffsets = resultOffsets;
1444 sourceOffsets[lastDim] =
1445 (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
1447 Value sourceSlice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1448 loc, bitCastOp.getSource(), sourceOffsets, sourceSliceShape,
1450 Value bitcastSlice = rewriter.
createOrFold<vector::BitCastOp>(
1451 loc, targetType, sourceSlice);
1453 loc, bitcastSlice,
result, resultOffsets, resultStrides);
1461 vector::UnrollVectorOptions options;
// Rewrite pattern that splits a vector::InterleaveOp into one smaller
// InterleaveOp per tile of the target shape. Some lines of the original
// definition (target-shape lookup, failure returns, final replacement) are not
// part of this excerpt.
1481struct UnrollInterleavePattern :
public OpRewritePattern<vector::InterleaveOp> {
// Stores the unroll options and forwards context/benefit to the base pattern.
1482 UnrollInterleavePattern(MLIRContext *context,
1483 const vector::UnrollVectorOptions &options,
1484 PatternBenefit benefit = 1)
1485 : OpRewritePattern<vector::InterleaveOp>(context, benefit),
1488 LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp,
1489 PatternRewriter &rewriter)
const override {
// Diagnostic text for the case where no target shape is available (the
// surrounding statement is not shown here).
1493 "failed to get target shape");
1495 VectorType resultType = interleaveOp.getResultVectorType();
1496 ArrayRef<int64_t> resultShape = resultType.getShape();
1497 Location loc = interleaveOp.getLoc();
// The target shape must have the same rank as the result vector.
1499 if (targetShape->size() != resultShape.size())
1501 interleaveOp,
"target shape rank must match result rank");
// Each operand slice has the target shape except along the last dimension,
// where it is half as large: a result tile is built from one lhs slice and one
// rhs slice.
1503 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1504 targetShape->end());
1505 int64_t lastDim = sourceSliceShape.size() - 1;
1506 sourceSliceShape[lastDim] = (*targetShape)[lastDim] / 2;
// `result` starts as a constant of the full result type (the constant's value
// is on a line not shown here).
1508 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1510 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1511 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
1513 VectorType targetType =
1514 VectorType::get(*targetShape, resultType.getElementType());
// Visit each tile of the result; `resultOffsets` is the tile's origin.
1516 for (SmallVector<int64_t> resultOffsets :
1517 StaticTileOffsetRange(resultShape, *targetShape)) {
// Operand offsets equal the result offsets, halved along the last dimension.
1518 SmallVector<int64_t> sourceOffsets = resultOffsets;
1519 sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
// Extract the matching slice from each operand.
1521 Value lhsSlice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1522 loc, interleaveOp.getLhs(), sourceOffsets, sourceSliceShape,
1524 Value rhsSlice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1525 loc, interleaveOp.getRhs(), sourceOffsets, sourceSliceShape,
// Interleave the two slices into one tile of the target type.
1527 Value interleaveSlice = rewriter.
createOrFold<vector::InterleaveOp>(
1528 loc, targetType, lhsSlice, rhsSlice);
// Arguments of a call whose callee is on a line not shown here; it receives
// the tile, the running `result`, the tile offsets and unit strides.
1530 loc, interleaveSlice,
result, resultOffsets, resultStrides);
// Unroll configuration captured at construction time.
1538 vector::UnrollVectorOptions options;
// Rewrite pattern that unrolls vector::DeinterleaveOp tile by tile. For each
// tile of the target shape within the result shape it extracts a source slice
// whose last dimension is twice as large, deinterleaves that slice, and
// inserts the two results into two zero-initialized vectors of the full
// result type.
// NOTE(review): the embedded line numbers skip (e.g. 1560, 1565-1566,
// 1569-1571, 1579, 1601, 1608, 1611-1618), so several statements below are
// shown without their first or last lines.
1559struct UnrollDeinterleavePattern
// Constructor: forwards context and benefit to the OpRewritePattern base.
// The options argument presumably initializes the options member declared at
// the end of the struct; that initializer line is not visible in this listing.
1561 UnrollDeinterleavePattern(MLIRContext *context,
1562 const vector::UnrollVectorOptions &options,
1563 PatternBenefit benefit = 1)
1564 : OpRewritePattern<vector::DeinterleaveOp>(context, benefit),
// Entry point invoked by the rewrite driver for each vector::DeinterleaveOp.
1567 LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
1568 PatternRewriter &rewriter)
const override {
// targetShape is computed on lines that are not visible here; the message
// below belongs to the bail-out taken when no target shape is available.
1572 "failed to get target shape");
1574 VectorType resultType = deinterleaveOp.getResultVectorType();
1575 ArrayRef<int64_t> resultShape = resultType.getShape();
1576 Location loc = deinterleaveOp.getLoc();
// The target shape must have exactly the same rank as the result vector.
1578 if (targetShape->size() != resultShape.size())
1580 deinterleaveOp,
"target shape rank must match result rank");
// Source slice shape: a copy of the target shape whose last dimension is
// doubled, because each source slice is split into two result tiles.
1582 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1583 targetShape->end());
1584 int64_t lastDim = sourceSliceShape.size() - 1;
1585 sourceSliceShape[lastDim] = (*targetShape)[lastDim] * 2;
// Two running results, both starting as zero constants of the result type.
1587 Value resultOdd = arith::ConstantOp::create(
1588 rewriter, loc, resultType, rewriter.
getZeroAttr(resultType));
1589 Value resultEven = arith::ConstantOp::create(
1590 rewriter, loc, resultType, rewriter.
getZeroAttr(resultType));
// Unit strides for both the result tiles and the source slices.
1591 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1592 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
// Visit every tile offset of the target shape within the result shape.
1594 for (SmallVector<int64_t> resultOffsets :
1595 StaticTileOffsetRange(resultShape, *targetShape)) {
// Source offsets equal the result offsets except for the last dimension,
// which is doubled to match the doubled slice extent.
1596 SmallVector<int64_t> sourceOffsets = resultOffsets;
1597 sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
1599 Value sourceSlice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1600 loc, deinterleaveOp.getSource(), sourceOffsets, sourceSliceShape,
// Deinterleave the slice. The op is created directly (not via createOrFold)
// so that both of its results can be read below.
1603 auto deinterleaveSlice =
1604 vector::DeinterleaveOp::create(rewriter, loc, sourceSlice);
// Insert each result of the tile into its running vector at resultOffsets.
// NOTE(review): getRes1() is accumulated into resultOdd and getRes2() into
// resultEven - confirm these names match the result order of DeinterleaveOp
// and the order used by the replacement (not visible in this listing).
1606 resultOdd = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
1607 loc, deinterleaveSlice.getRes1(), resultOdd, resultOffsets,
1609 resultEven = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
1610 loc, deinterleaveSlice.getRes2(), resultEven, resultOffsets,
// Unrolling options held by the pattern.
1619 vector::UnrollVectorOptions options;
// Adds the vector unrolling rewrite patterns named in the add<...> list below
// to the pattern set called patterns.
// NOTE(review): the parameter list (lines 1625-1626) and the end of the
// template argument list (from line 1635 on) are not visible in this listing.
1624void mlir::vector::populateVectorUnrollPatterns(
1627 patterns.
add<UnrollTransferReadPattern, UnrollTransferWritePattern,
1628 UnrollContractionPattern, UnrollElementwisePattern,
1629 UnrollReductionPattern, UnrollMultiReductionPattern,
1630 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
1631 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1632 UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1633 UnrollCreateMaskPattern, UnrollConstantMaskPattern,
1634 UnrollBitCastPattern, UnrollInterleavePattern,
1639void mlir::vector::populateVectorToElementsUnrollPatterns(
1645void mlir::vector::populateVectorFromElementsUnrollPatterns(
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class implements the operand iterators for the Operation class.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
OperationName getName()
The name of an operation is the key identifier for it.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.