18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/InterleavedRange.h"
24#define DEBUG_TYPE "vector-unroll"
37 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
38 return constExpr.getValue() == 0;
43 for (
const auto &dim : llvm::enumerate(permutationMap.
getResults())) {
44 if (isBroadcast(dim.value()))
46 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
51 affine::AffineApplyOp::create(builder, loc, map,
indices[pos]);
63 assert(offsets.size() <= originalIndices.size() &&
64 "Offsets should not exceed the number of original indices");
67 auto start =
indices.size() - offsets.size();
68 for (
auto [i, offset] : llvm::enumerate(offsets)) {
70 indices[start + i] = arith::AddIOp::create(
71 rewriter, loc, originalIndices[start + i],
90static std::optional<SmallVector<int64_t>>
93 if (
options.filterConstraint && failed(
options.filterConstraint(op))) {
94 LDBG() <<
"--no filter constraint -> BAIL";
98 "vector unrolling expects the native shape or native"
99 "shape call back function to be set");
100 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
101 if (!unrollableVectorOp) {
102 LDBG() <<
"--not an unrollable op -> BAIL";
105 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
106 if (!maybeUnrollShape) {
107 LDBG() <<
"--could not get shape of op " << *op <<
" -> BAIL";
110 LDBG() <<
"--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
112 std::optional<SmallVector<int64_t>> targetShape =
options.nativeShape(op);
114 LDBG() <<
"--no unrolling target shape defined " << *op <<
"-> SKIP";
117 LDBG() <<
"--target shape: " << llvm::interleaved(*targetShape);
120 if (!maybeShapeRatio) {
121 LDBG() <<
"--could not compute integral shape ratio -> BAIL";
124 if (llvm::all_of(*maybeShapeRatio, [](
int64_t v) {
return v == 1; })) {
125 LDBG() <<
"--no unrolling needed -> SKIP";
128 LDBG() <<
"--found an integral shape ratio to unroll to -> SUCCESS";
136 llvm::to_vector(llvm::seq<int64_t>(0,
static_cast<int64_t>(numLoops)));
137 if (
options.traversalOrderCallback !=
nullptr) {
138 std::optional<SmallVector<int64_t>> order =
139 options.traversalOrderCallback(op);
141 loopOrder = std::move(*order);
149struct UnrollTransferReadPattern
151 UnrollTransferReadPattern(MLIRContext *context,
152 const vector::UnrollVectorOptions &options,
153 PatternBenefit benefit = 1)
154 : OpRewritePattern<vector::TransferReadOp>(context, benefit),
157 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
158 PatternRewriter &rewriter)
const override {
160 if (readOp.getTransferRank() == 0)
162 if (readOp.getMask())
167 auto sourceVectorType = readOp.getVectorType();
168 SmallVector<int64_t> strides(targetShape->size(), 1);
169 Location loc = readOp.getLoc();
170 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
174 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
177 VectorType::get(*targetShape, sourceVectorType.getElementType());
178 SmallVector<Value> originalIndices(readOp.getIndices().begin(),
179 readOp.getIndices().end());
180 SmallVector<int64_t> loopOrder =
182 for (SmallVector<int64_t> elementOffsets :
183 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
186 readOp.getPermutationMap(), loc, rewriter);
187 auto slicedRead = vector::TransferReadOp::create(
188 rewriter, loc, targetType, readOp.getBase(),
indices,
189 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
190 readOp.getInBoundsAttr());
193 loc, slicedRead,
result, elementOffsets, strides);
200 vector::UnrollVectorOptions options;
203struct UnrollTransferWritePattern
205 UnrollTransferWritePattern(MLIRContext *context,
206 const vector::UnrollVectorOptions &options,
207 PatternBenefit benefit = 1)
208 : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
211 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
212 PatternRewriter &rewriter)
const override {
214 if (writeOp.getTransferRank() == 0)
217 if (writeOp.getMask())
222 auto sourceVectorType = writeOp.getVectorType();
223 SmallVector<int64_t> strides(targetShape->size(), 1);
224 Location loc = writeOp.getLoc();
225 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
229 if (originalSize.size() != targetShape->size())
232 "expected source input vector rank to match target shape rank");
234 SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
235 writeOp.getIndices().end());
236 SmallVector<int64_t> loopOrder =
239 for (SmallVector<int64_t> elementOffsets :
240 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
241 Value slicedVector = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
242 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
245 writeOp.getPermutationMap(), loc, rewriter);
246 Operation *slicedWrite = vector::TransferWriteOp::create(
247 rewriter, loc, slicedVector,
248 resultTensor ? resultTensor : writeOp.getBase(),
indices,
249 writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
252 resultTensor = slicedWrite->
getResult(0);
255 rewriter.
replaceOp(writeOp, resultTensor);
262 vector::UnrollVectorOptions options;
265struct OffsetMapInfo {
// Returns the one-element vector {-1}.
// NOTE(review): OffsetMapInfo is passed as the key-info template argument of
// llvm::DenseMap elsewhere in this file, so this is presumably the reserved
// "empty" sentinel key; it assumes no real offset vector equals {-1} —
// confirm against the callers.
266 static SmallVector<int64_t> getEmptyKey() {
return {int64_t(-1)}; }
// Returns the one-element vector {-2}, distinct from the {-1} returned by
// getEmptyKey().
// NOTE(review): presumably the reserved "tombstone" (erased-slot) sentinel
// for llvm::DenseMap; it assumes no real offset vector equals {-2} —
// confirm against the callers.
268 static SmallVector<int64_t> getTombstoneKey() {
return {int64_t(-2)}; }
270 static unsigned getHashValue(
const SmallVector<int64_t> &v) {
271 return static_cast<unsigned>(llvm::hash_combine_range(v));
274 static bool isEqual(
const SmallVector<int64_t> &
lhs,
275 const SmallVector<int64_t> &
rhs) {
280struct UnrollContractionPattern
282 UnrollContractionPattern(MLIRContext *context,
283 const vector::UnrollVectorOptions &options,
284 PatternBenefit benefit = 1)
285 : OpRewritePattern<vector::ContractionOp>(context, benefit),
288 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289 PatternRewriter &rewriter)
const override {
293 auto dstVecType = cast<VectorType>(contractOp.getResultType());
294 SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
296 Location loc = contractOp.getLoc();
297 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
298 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
300 SmallVector<int64_t>, Value,
301 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
305 contractOp.getIteratorTypes().size(), contractOp, options);
307 for (SmallVector<int64_t> offsets :
308 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
309 SmallVector<Value> slicesOperands(contractOp.getNumOperands());
312 auto extractOperand = [&](
unsigned index, Value operand,
313 AffineMap permutationMap,
314 ArrayRef<int64_t> operandOffets) {
316 permutationMap, ArrayRef<int64_t>(*targetShape));
317 SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
318 slicesOperands[index] =
320 loc, operand, operandOffets, operandShape, operandStrides);
324 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
325 SmallVector<int64_t> lhsOffets =
327 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
330 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
331 SmallVector<int64_t> rhsOffets =
333 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
335 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
336 SmallVector<int64_t> accOffets =
340 auto *accIt = accCache.find(accOffets);
341 if (accIt != accCache.end())
342 slicesOperands[2] = accIt->second;
344 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
346 SmallVector<int64_t> dstShape =
348 auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
350 rewriter, loc, contractOp, slicesOperands, targetType);
352 SmallVector<int64_t> dstOffets =
356 accCache[dstOffets] = newOp->
getResult(0);
359 Value
result = arith::ConstantOp::create(rewriter, loc, dstVecType,
361 for (
const auto &it : accCache) {
362 SmallVector<int64_t> dstStrides(it.first.size(), 1);
364 loc, it.second,
result, it.first, dstStrides);
371 vector::UnrollVectorOptions options;
374struct UnrollMultiReductionPattern
376 UnrollMultiReductionPattern(MLIRContext *context,
377 const vector::UnrollVectorOptions &options,
378 PatternBenefit benefit = 1)
379 : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
382 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
383 PatternRewriter &rewriter)
const override {
384 std::optional<SmallVector<int64_t>> targetShape =
388 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
389 Location loc = reductionOp.getLoc();
390 auto resultType = reductionOp->getResult(0).getType();
395 if (resultType.isIntOrFloat()) {
396 Value accumulator = reductionOp.getAcc();
397 for (SmallVector<int64_t> offsets :
398 StaticTileOffsetRange(originalSize, *targetShape)) {
399 SmallVector<int64_t> operandStrides(offsets.size(), 1);
400 Value slicedOperand =
402 loc, reductionOp.getSource(), offsets, *targetShape,
405 rewriter, loc, reductionOp, {slicedOperand, accumulator},
409 rewriter.
replaceOp(reductionOp, accumulator);
415 SmallVector<int64_t>, Value,
416 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
421 for (SmallVector<int64_t> offsets :
422 StaticTileOffsetRange(originalSize, *targetShape)) {
423 SmallVector<Value> operands;
424 SmallVector<int64_t> operandStrides(offsets.size(), 1);
425 Value slicedOperand =
427 loc, reductionOp.getSource(), offsets, *targetShape,
429 operands.push_back(slicedOperand);
430 SmallVector<int64_t> dstShape;
431 SmallVector<int64_t> destOffset;
432 for (
size_t i : llvm::seq(
size_t(0), targetShape->size())) {
433 if (!reductionOp.isReducedDim(i)) {
434 destOffset.push_back(offsets[i]);
435 dstShape.push_back((*targetShape)[i]);
439 SmallVector<int64_t> accStrides(destOffset.size(), 1);
442 auto *accIt = accCache.find(destOffset);
443 if (accIt != accCache.end())
446 acc = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
447 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
448 operands.push_back(acc);
449 auto targetType = VectorType::get(
450 dstShape, reductionOp.getSourceVectorType().getElementType());
452 operands, targetType);
454 accCache[destOffset] =
result;
457 Value
result = arith::ConstantOp::create(
458 rewriter, loc, reductionOp.getDestType(),
460 for (
const auto &it : accCache) {
461 SmallVector<int64_t> dstStrides(it.first.size(), 1);
463 loc, it.second,
result, it.first, dstStrides);
470 vector::UnrollVectorOptions options;
474 UnrollElementwisePattern(MLIRContext *context,
475 const vector::UnrollVectorOptions &options,
476 PatternBenefit benefit = 1)
477 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
480 LogicalResult matchAndRewrite(Operation *op,
481 PatternRewriter &rewriter)
const override {
487 int64_t targetShapeRank = targetShape->size();
489 SmallVector<int64_t> originalSize =
490 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
491 int64_t originalShapeRank = originalSize.size();
493 Location loc = op->
getLoc();
496 SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
497 int64_t rankDiff = originalShapeRank - targetShapeRank;
498 std::fill(adjustedTargetShape.begin(),
499 adjustedTargetShape.begin() + rankDiff, 1);
500 std::copy(targetShape->begin(), targetShape->end(),
501 adjustedTargetShape.begin() + rankDiff);
503 int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
505 Value
result = arith::ConstantOp::create(rewriter, loc, dstVecType,
507 SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
508 VectorType unrolledVecType =
509 VectorType::get(*targetShape, dstVecType.getElementType());
512 for (SmallVector<int64_t> offsets :
513 StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
514 SmallVector<Value> extractOperands;
516 auto vecType = dyn_cast<VectorType>(operand.get().getType());
518 extractOperands.push_back(operand.get());
521 Value extracted = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
522 loc, operand.get(), offsets, adjustedTargetShape, strides);
525 if (adjustedTargetShapeRank > targetShapeRank) {
527 loc, VectorType::get(*targetShape, vecType.getElementType()),
530 extractOperands.push_back(extracted);
534 rewriter, loc, op, extractOperands, unrolledVecType);
536 Value computeResult = newOp->
getResult(0);
539 SmallVector<int64_t> insertStrides =
540 (adjustedTargetShapeRank > targetShapeRank)
541 ? SmallVector<int64_t>(targetShapeRank, 1)
545 loc, computeResult,
result, offsets, insertStrides);
552 vector::UnrollVectorOptions options;
555struct UnrollReductionPattern :
public OpRewritePattern<vector::ReductionOp> {
556 UnrollReductionPattern(MLIRContext *context,
557 const vector::UnrollVectorOptions &options,
558 PatternBenefit benefit = 1)
559 : OpRewritePattern<vector::ReductionOp>(context, benefit),
562 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
563 PatternRewriter &rewriter)
const override {
564 std::optional<SmallVector<int64_t>> targetShape =
568 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
571 Location loc = reductionOp.getLoc();
572 Value accumulator =
nullptr;
573 for (SmallVector<int64_t> offsets :
574 StaticTileOffsetRange(originalSize, *targetShape)) {
575 SmallVector<int64_t> strides(offsets.size(), 1);
576 Value slicedOperand =
578 loc, reductionOp.getVector(), offsets, *targetShape, strides);
580 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
593 rewriter.
replaceOp(reductionOp, accumulator);
598 const vector::UnrollVectorOptions options;
601struct UnrollTransposePattern :
public OpRewritePattern<vector::TransposeOp> {
602 UnrollTransposePattern(MLIRContext *context,
603 const vector::UnrollVectorOptions &options,
604 PatternBenefit benefit = 1)
605 : OpRewritePattern<vector::TransposeOp>(context, benefit),
608 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
609 PatternRewriter &rewriter)
const override {
610 if (transposeOp.getResultVectorType().getRank() == 0)
615 auto originalVectorType = transposeOp.getResultVectorType();
616 SmallVector<int64_t> strides(targetShape->size(), 1);
617 Location loc = transposeOp.getLoc();
618 ArrayRef<int64_t> originalSize = originalVectorType.getShape();
622 arith::ConstantOp::create(rewriter, loc, originalVectorType,
624 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
627 for (SmallVector<int64_t> elementOffsets :
628 StaticTileOffsetRange(originalSize, *targetShape)) {
629 SmallVector<int64_t> permutedOffsets(elementOffsets.size());
630 SmallVector<int64_t> permutedShape(elementOffsets.size());
632 for (
auto indices : llvm::enumerate(permutation)) {
633 permutedOffsets[
indices.value()] = elementOffsets[
indices.index()];
636 Value slicedOperand =
638 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
640 Value transposedSlice = rewriter.
createOrFold<vector::TransposeOp>(
641 loc, slicedOperand, permutation);
643 loc, transposedSlice,
result, elementOffsets, strides);
650 vector::UnrollVectorOptions options;
654 UnrollGatherPattern(MLIRContext *context,
655 const vector::UnrollVectorOptions &options,
656 PatternBenefit benefit = 1)
657 : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
660 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
661 PatternRewriter &rewriter)
const override {
662 VectorType sourceVectorType = gatherOp.getVectorType();
663 if (sourceVectorType.getRank() == 0)
668 SmallVector<int64_t> strides(targetShape->size(), 1);
669 Location loc = gatherOp.getLoc();
670 ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
674 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
677 VectorType::get(*targetShape, sourceVectorType.getElementType());
679 SmallVector<int64_t> loopOrder =
681 for (SmallVector<int64_t> elementOffsets :
682 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
686 Value indexSubVec = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
687 loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
688 Value maskSubVec = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
689 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
690 Value passThruSubVec =
692 loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
694 auto slicedGather = vector::GatherOp::create(
695 rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
696 indexSubVec, maskSubVec, passThruSubVec);
699 loc, slicedGather,
result, elementOffsets, strides);
706 vector::UnrollVectorOptions options;
710 UnrollLoadPattern(MLIRContext *context,
711 const vector::UnrollVectorOptions &options,
712 PatternBenefit benefit = 1)
713 : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
715 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
716 PatternRewriter &rewriter)
const override {
717 VectorType vecType = loadOp.getVectorType();
723 Location loc = loadOp.getLoc();
724 ArrayRef<int64_t> originalShape = vecType.getShape();
725 SmallVector<int64_t> strides(targetShape->size(), 1);
727 Value
result = arith::ConstantOp::create(rewriter, loc, vecType,
730 SmallVector<int64_t> loopOrder =
734 VectorType::get(*targetShape, vecType.getElementType());
736 for (SmallVector<int64_t> offsets :
737 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
740 Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
743 loc, slicedLoad,
result, offsets, strides);
750 vector::UnrollVectorOptions options;
754 UnrollStorePattern(MLIRContext *context,
755 const vector::UnrollVectorOptions &options,
756 PatternBenefit benefit = 1)
757 : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
759 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
760 PatternRewriter &rewriter)
const override {
761 VectorType vecType = storeOp.getVectorType();
767 Location loc = storeOp.getLoc();
768 ArrayRef<int64_t> originalShape = vecType.getShape();
769 SmallVector<int64_t> strides(targetShape->size(), 1);
771 Value base = storeOp.getBase();
772 Value vector = storeOp.getValueToStore();
774 SmallVector<int64_t> loopOrder =
777 for (SmallVector<int64_t> offsets :
778 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
781 Value slice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
782 loc, vector, offsets, *targetShape, strides);
783 vector::StoreOp::create(rewriter, loc, slice, base,
indices);
790 vector::UnrollVectorOptions options;
793struct UnrollBroadcastPattern :
public OpRewritePattern<vector::BroadcastOp> {
794 UnrollBroadcastPattern(MLIRContext *context,
795 const vector::UnrollVectorOptions &options,
796 PatternBenefit benefit = 1)
797 : OpRewritePattern<vector::BroadcastOp>(context, benefit),
800 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
801 PatternRewriter &rewriter)
const override {
806 Location loc = broadcastOp.getLoc();
807 VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
808 VectorType resType = broadcastOp.getResultVectorType();
809 VectorType targetType =
810 resType.cloneWith(*targetShape, resType.getElementType());
811 Value
result = arith::ConstantOp::create(rewriter, loc, resType,
814 SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
815 SmallVector<int64_t> strides(originalShape.size(), 1);
817 for (SmallVector<int64_t> offsets :
818 StaticTileOffsetRange(originalShape, *targetShape)) {
822 newSrc = broadcastOp.getSource();
825 int64_t rank = srcType.getRank();
826 SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
827 SmallVector<int64_t> srcShape(targetShape->end() - rank,
829 SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
831 for (int64_t i = 0; i < rank; ++i) {
832 if (srcType.getDimSize(i) == 1) {
837 newSrc = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
838 loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
853 vector::UnrollVectorOptions options;
873struct UnrollToElements final :
public OpRewritePattern<vector::ToElementsOp> {
874 UnrollToElements(MLIRContext *context,
875 const vector::UnrollVectorOptions &options,
876 PatternBenefit benefit = 1)
877 : OpRewritePattern<vector::ToElementsOp>(context, benefit),
880 LogicalResult matchAndRewrite(vector::ToElementsOp op,
881 PatternRewriter &rewriter)
const override {
884 FailureOr<SmallVector<Value>>
result =
889 SmallVector<Value> vectors = *
result;
891 SmallVector<Value> results;
892 for (Value vector : vectors) {
894 vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
895 llvm::append_range(results, subElements.getResults());
902 vector::UnrollVectorOptions options;
932 UnrollStepPattern(MLIRContext *context,
933 const vector::UnrollVectorOptions &options,
934 PatternBenefit benefit = 1)
935 : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
937 LogicalResult matchAndRewrite(vector::StepOp stepOp,
938 PatternRewriter &rewriter)
const override {
939 std::optional<SmallVector<int64_t>> targetShape =
944 VectorType vecType = stepOp.getType();
945 if (vecType.isScalable()) {
949 int64_t originalSize = vecType.getShape()[0];
950 Location loc = stepOp.getLoc();
951 SmallVector<int64_t> strides(1, 1);
953 Value
result = arith::ConstantOp::create(rewriter, loc, vecType,
957 VectorType::get(*targetShape, vecType.getElementType());
958 Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
959 for (
const SmallVector<int64_t> &offsets :
960 StaticTileOffsetRange({originalSize}, *targetShape)) {
961 Value bcastOffset = arith::ConstantOp::create(
962 rewriter, loc, targetVecType,
965 IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
967 arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
970 loc, tileStep,
result, offsets, strides);
977 vector::UnrollVectorOptions options;
998 UnrollFromElements(MLIRContext *context,
999 const vector::UnrollVectorOptions &options,
1000 PatternBenefit benefit = 1)
1001 : OpRewritePattern<vector::FromElementsOp>(context, benefit),
1004 LogicalResult matchAndRewrite(vector::FromElementsOp op,
1005 PatternRewriter &rewriter)
const override {
1008 auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
1009 VectorType subTy, int64_t index) {
1010 size_t subTyNumElements = subTy.getNumElements();
1011 assert((index + 1) * subTyNumElements <= allElements.size() &&
1014 allElements.slice(index * subTyNumElements, subTyNumElements);
1015 return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
1022 vector::UnrollVectorOptions options;
1058struct UnrollCreateMaskPattern :
public OpRewritePattern<vector::CreateMaskOp> {
1059 UnrollCreateMaskPattern(MLIRContext *context,
1060 const vector::UnrollVectorOptions &options,
1061 PatternBenefit benefit = 1)
1062 : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1065 LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
1066 PatternRewriter &rewriter)
const override {
1071 VectorType resultType = createMaskOp.getVectorType();
1072 SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
1073 Location loc = createMaskOp.getLoc();
1075 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1077 VectorType targetVectorType =
1078 VectorType::get(*targetShape, rewriter.
getI1Type());
1079 SmallVector<int64_t> strides(targetShape->size(), 1);
1083 for (SmallVector<int64_t> offsets :
1084 StaticTileOffsetRange(originalSize, *targetShape)) {
1085 SmallVector<Value> unrolledOperands;
1087 for (
auto [i, originalMaskOperand] :
1088 llvm::enumerate(createMaskOp.getOperands())) {
1091 Value adjustedMaskSize = rewriter.
createOrFold<arith::SubIOp>(
1092 loc, originalMaskOperand, offsetVal);
1094 Value unrolledDimSize =
1097 rewriter.
createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1098 Value unrolledOperand = rewriter.
createOrFold<arith::MinSIOp>(
1099 loc, nonNegative, unrolledDimSize);
1100 unrolledOperands.push_back(unrolledOperand);
1103 auto unrolledMask = rewriter.
createOrFold<vector::CreateMaskOp>(
1104 loc, targetVectorType, unrolledOperands);
1106 loc, unrolledMask,
result, offsets, strides);
1113 vector::UnrollVectorOptions options;
1148struct UnrollConstantMaskPattern
1150 UnrollConstantMaskPattern(MLIRContext *context,
1151 const vector::UnrollVectorOptions &options,
1152 PatternBenefit benefit = 1)
1153 : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1156 LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
1157 PatternRewriter &rewriter)
const override {
1158 std::optional<SmallVector<int64_t>> targetShape =
1163 VectorType resultType = constantMaskOp.getVectorType();
1164 SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
1165 Location loc = constantMaskOp.getLoc();
1167 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1169 VectorType targetVectorType =
1170 VectorType::get(*targetShape, rewriter.
getI1Type());
1171 SmallVector<int64_t> strides(targetShape->size(), 1);
1175 for (
const SmallVector<int64_t> &offsets :
1176 StaticTileOffsetRange(originalSize, *targetShape)) {
1177 SmallVector<int64_t> unrolledMaskDims;
1179 for (
auto [i, originalMaskDim] :
1180 llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
1183 int64_t adjustedMaskSize =
1184 std::max(originalMaskDim - offsets[i],
static_cast<int64_t
>(0));
1185 int64_t unrolledMaskDim =
1186 std::min(adjustedMaskSize,
static_cast<int64_t
>((*targetShape)[i]));
1187 unrolledMaskDims.push_back(unrolledMaskDim);
1190 auto unrolledMask = rewriter.
createOrFold<vector::ConstantMaskOp>(
1191 loc, targetVectorType, unrolledMaskDims);
1193 loc, unrolledMask,
result, offsets, strides);
1200 vector::UnrollVectorOptions options;
1218 if (extractShape.empty() ||
shape.empty() ||
1219 extractShape.size() >
shape.size())
1222 while (extractShape.size() > 1 && extractShape.front() == 1)
1223 extractShape = extractShape.drop_front();
1225 while (
shape.size() > 1 &&
shape.front() == 1) {
1229 size_t rankDiff =
shape.size() - extractShape.size();
1230 if (!llvm::equal(extractShape.drop_front(),
shape.drop_front(rankDiff + 1)))
1233 int64_t extractElements = ShapedType::getNumElements(extractShape);
1234 int64_t shapeElements = ShapedType::getNumElements(
shape);
1235 return shapeElements % extractElements == 0;
1257static std::optional<SmallVector<int64_t>>
1261 int64_t remainingElements = targetElements;
1264 for (
int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
1265 int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1266 extractShape.insert(extractShape.begin(), takeFromDim);
1268 if (remainingElements % takeFromDim != 0)
1269 return std::nullopt;
1270 remainingElements /= takeFromDim;
1274 while (extractShape.size() < sourceShape.size())
1275 extractShape.insert(extractShape.begin(), 1);
1277 if (ShapedType::getNumElements(extractShape) != targetElements)
1278 return std::nullopt;
1280 return extractShape;
1325struct UnrollShapeCastPattern :
public OpRewritePattern<vector::ShapeCastOp> {
1326 UnrollShapeCastPattern(MLIRContext *context,
1327 const vector::UnrollVectorOptions &options,
1328 PatternBenefit benefit = 1)
1329 : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1332 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1333 PatternRewriter &rewriter)
const override {
1334 std::optional<SmallVector<int64_t>> targetShape =
1339 VectorType sourceType = shapeCastOp.getSourceVectorType();
1340 VectorType resultType = shapeCastOp.getResultVectorType();
1341 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1342 ArrayRef<int64_t> resultShape = resultType.getShape();
1344 if (!isContiguous(*targetShape, resultShape))
1346 shapeCastOp,
"Only supports cases where target shape is "
1347 "contiguous in result vector shape");
1349 int64_t targetElements = ShapedType::getNumElements(*targetShape);
1352 std::optional<SmallVector<int64_t>> extractShape =
1353 calculateSourceExtractShape(sourceShape, targetElements);
1357 "cannot extract target number of elements contiguously from source");
1359 Location loc = shapeCastOp.getLoc();
1362 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1365 VectorType targetType =
1366 VectorType::get(*targetShape, sourceType.getElementType());
1369 SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1371 for (SmallVector<int64_t> resultOffsets :
1372 StaticTileOffsetRange(resultShape, *targetShape)) {
1373 SmallVector<int64_t> sourceOffsets =
1374 calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
1375 Value sourceChunk = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1376 loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1378 Value targetChunk = rewriter.
createOrFold<vector::ShapeCastOp>(
1379 loc, targetType, sourceChunk);
1381 loc, targetChunk,
result, resultOffsets, insertStrides);
1389 vector::UnrollVectorOptions options;
1394void mlir::vector::populateVectorUnrollPatterns(
1397 patterns.
add<UnrollTransferReadPattern, UnrollTransferWritePattern,
1398 UnrollContractionPattern, UnrollElementwisePattern,
1399 UnrollReductionPattern, UnrollMultiReductionPattern,
1400 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
1401 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1402 UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1403 UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
1407void mlir::vector::populateVectorToElementsUnrollPatterns(
1413void mlir::vector::populateVectorFromElementsUnrollPatterns(
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static SmallVector< Value > sliceTransferIndices(ArrayRef< int64_t > elementOffsets, ArrayRef< Value > indices, AffineMap permutationMap, Location loc, OpBuilder &builder)
Compute the indices of the slice index for a transfer op.
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class implements the operand iterators for the Operation class.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
OperationName getName()
The name of an operation is the key identifier for it.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx....
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.