18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/InterleavedRange.h"
24#define DEBUG_TYPE "vector-unroll"
// NOTE(review): fragment — the enclosing helper's signature is absent from
// this extract (original numbering jumps straight to line 34). As visible,
// an affine expression that folds to the constant 0 is reported as a
// broadcast dimension.
34 if (
auto constExpr = dyn_cast<AffineConstantExpr>(expr))
35 return constExpr.getValue() == 0;
// NOTE(review): fragment of the transfer-index slicing loop. For each result
// expr of the permutation map it skips broadcast dims and zero offsets, then
// shifts indices[pos] by the element offset via an affine.apply. Original
// lines 43 and 45-48 (the `continue` and the offset map construction) are
// missing from this extract — code kept byte-identical.
40 for (
const auto &dim : llvm::enumerate(permutationMap.
getResults())) {
41 int64_t elementOffset = elementOffsets[dim.index()];
42 if (isBroadcast(dim.value()) || elementOffset == 0)
44 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
// Shift the in-bounds index by the per-tile element offset.
49 affine::AffineApplyOp::create(builder, loc, map,
indices[pos]);
// NOTE(review): fragment of an index-adjustment helper for load/store
// unrolling. Offsets apply to the innermost dims only (hence `start`);
// each affected index is bumped with an arith.addi. The enclosing
// signature and several interior lines are missing from this extract.
61 assert(offsets.size() <= originalIndices.size() &&
62 "Offsets should not exceed the number of original indices");
// Offsets align with the trailing (minor) dimensions of the indices.
65 auto start =
indices.size() - offsets.size();
66 for (
auto [i, offset] : llvm::enumerate(offsets)) {
68 indices[start + i] = arith::AddIOp::create(
69 rewriter, loc, originalIndices[start + i],
// Computes the target shape an op should be unrolled to, or (per the LDBG
// messages) bails/skips when: the user filter constraint rejects the op, the
// op does not implement VectorUnrollOpInterface, its shape-for-unroll cannot
// be computed, no native target shape is configured, the shape ratio is not
// integral, or the ratio is all-ones (nothing to unroll).
// NOTE(review): many interior lines (the `return std::nullopt;` paths, the
// shape-ratio computation on original line ~117) are missing from this
// extract; code kept byte-identical.
88static std::optional<SmallVector<int64_t>>
91 if (
options.filterConstraint && failed(
options.filterConstraint(op))) {
92 LDBG() <<
"--no filter constraint -> BAIL";
96 "vector unrolling expects the native shape or native"
97 "shape call back function to be set");
98 auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
99 if (!unrollableVectorOp) {
100 LDBG() <<
"--not an unrollable op -> BAIL";
103 auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
104 if (!maybeUnrollShape) {
105 LDBG() <<
"--could not get shape of op " << *op <<
" -> BAIL";
108 LDBG() <<
"--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
// Query the user-provided callback for the per-op native shape.
110 std::optional<SmallVector<int64_t>> targetShape =
options.nativeShape(op);
112 LDBG() <<
"--no unrolling target shape defined " << *op <<
"-> SKIP";
115 LDBG() <<
"--target shape: " << llvm::interleaved(*targetShape);
118 if (!maybeShapeRatio) {
119 LDBG() <<
"--could not compute integral shape ratio -> BAIL";
// A ratio of all ones means the op already matches the target shape.
122 if (llvm::all_of(*maybeShapeRatio, [](
int64_t v) {
return v == 1; })) {
123 LDBG() <<
"--no unrolling needed -> SKIP";
126 LDBG() <<
"--found an integral shape ratio to unroll to -> SUCCESS";
// NOTE(review): fragment of the unroll-order helper: defaults to the
// identity order [0, numLoops) and lets an optional user callback override
// it. The enclosing signature and the callback-failure handling are missing
// from this extract.
134 llvm::to_vector(llvm::seq<int64_t>(0,
static_cast<int64_t>(numLoops)));
135 if (
options.traversalOrderCallback !=
nullptr) {
136 std::optional<SmallVector<int64_t>> order =
137 options.traversalOrderCallback(op);
139 loopOrder = std::move(*order);
// Pattern that unrolls vector.transfer_read into target-shape tiles: each
// tile re-reads at sliced indices and is inserted back into a poison/constant
// accumulator vector. Bails on rank-0 transfers and masked reads.
// NOTE(review): interior lines (failure returns, `indices`/`result`
// declarations, InsertStridedSliceOp creation, final replaceOp) are missing
// from this extract — code kept byte-identical.
147struct UnrollTransferReadPattern
149 UnrollTransferReadPattern(MLIRContext *context,
150 const vector::UnrollVectorOptions &options,
151 PatternBenefit benefit = 1)
152 : OpRewritePattern<vector::TransferReadOp>(context, benefit),
155 LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
156 PatternRewriter &rewriter)
const override {
// Guard: 0-d transfers and masked reads are not handled here.
158 if (readOp.getTransferRank() == 0)
160 if (readOp.getMask())
165 auto sourceVectorType = readOp.getVectorType();
166 SmallVector<int64_t> strides(targetShape->size(), 1);
167 Location loc = readOp.getLoc();
168 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
172 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
175 VectorType::get(*targetShape, sourceVectorType.getElementType());
176 SmallVector<Value> originalIndices(readOp.getIndices().begin(),
177 readOp.getIndices().end());
178 SmallVector<int64_t> loopOrder =
// Iterate over every tile of the original shape in the chosen order.
180 for (SmallVector<int64_t> elementOffsets :
181 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
183 sliceTransferIndices(elementOffsets, originalIndices,
184 readOp.getPermutationMap(), loc, rewriter);
185 auto slicedRead = vector::TransferReadOp::create(
186 rewriter, loc, targetType, readOp.getBase(),
indices,
187 readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
188 readOp.getInBoundsAttr());
191 loc, slicedRead,
result, elementOffsets, strides);
198 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.transfer_write: extracts each target-shape
// slice of the written vector and issues a per-tile transfer_write, chaining
// the tensor result (resultTensor) between tiles when writing into a tensor.
// NOTE(review): interior lines (failure returns, `indices` declaration,
// targetShape computation) are missing from this extract — code kept
// byte-identical.
201struct UnrollTransferWritePattern
203 UnrollTransferWritePattern(MLIRContext *context,
204 const vector::UnrollVectorOptions &options,
205 PatternBenefit benefit = 1)
206 : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
209 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
210 PatternRewriter &rewriter)
const override {
// Guard: 0-d transfers and masked writes are not handled here.
212 if (writeOp.getTransferRank() == 0)
215 if (writeOp.getMask())
220 auto sourceVectorType = writeOp.getVectorType();
221 SmallVector<int64_t> strides(targetShape->size(), 1);
222 Location loc = writeOp.getLoc();
223 ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
227 if (originalSize.size() != targetShape->size())
230 "expected source input vector rank to match target shape rank");
232 SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
233 writeOp.getIndices().end());
234 SmallVector<int64_t> loopOrder =
237 for (SmallVector<int64_t> elementOffsets :
238 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
// Extract the tile of the source vector that this iteration writes.
239 Value slicedVector = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
240 loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
242 sliceTransferIndices(elementOffsets, originalIndices,
243 writeOp.getPermutationMap(), loc, rewriter);
244 Operation *slicedWrite = vector::TransferWriteOp::create(
245 rewriter, loc, slicedVector,
// Chain tensor results: later tiles write into the prior result.
246 resultTensor ? resultTensor : writeOp.getBase(),
indices,
247 writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
250 resultTensor = slicedWrite->
getResult(0);
253 rewriter.
replaceOp(writeOp, resultTensor);
260 vector::UnrollVectorOptions options;
// DenseMapInfo-style traits so SmallVector<int64_t> offset vectors can key a
// llvm::DenseMap (used as the accumulator cache in the unroll patterns).
// Sentinels {-1} / {-2} follow the empty/tombstone convention.
// NOTE(review): the isEqual body and closing braces are missing from this
// extract — code kept byte-identical.
263struct OffsetMapInfo {
264 static SmallVector<int64_t> getEmptyKey() {
return {int64_t(-1)}; }
266 static SmallVector<int64_t> getTombstoneKey() {
return {int64_t(-2)}; }
268 static unsigned getHashValue(
const SmallVector<int64_t> &v) {
// Hash the whole offset vector as a range.
269 return static_cast<unsigned>(llvm::hash_combine_range(v));
272 static bool isEqual(
const SmallVector<int64_t> &
lhs,
273 const SmallVector<int64_t> &
rhs) {
// Pattern that unrolls vector.contract: tiles the iteration space, slices
// lhs/rhs/acc per tile via each operand's indexing map, reuses partial
// accumulators through accCache (keyed by destination offsets, OffsetMapInfo
// above), and finally assembles the result by inserting each cached partial
// into a constant destination vector.
// NOTE(review): interior lines (targetShape/loopOrder computation, the
// slice-offset applications on original lines 324/330/335, the final
// replaceOp) are missing from this extract — code kept byte-identical.
278struct UnrollContractionPattern
280 UnrollContractionPattern(MLIRContext *context,
281 const vector::UnrollVectorOptions &options,
282 PatternBenefit benefit = 1)
283 : OpRewritePattern<vector::ContractionOp>(context, benefit),
286 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
287 PatternRewriter &rewriter)
const override {
291 auto dstVecType = cast<VectorType>(contractOp.getResultType());
292 SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
294 Location loc = contractOp.getLoc();
295 unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
296 AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
// Cache of partial accumulators keyed by destination tile offsets.
298 SmallVector<int64_t>, Value,
299 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
303 contractOp.getIteratorTypes().size(), contractOp, options);
305 for (SmallVector<int64_t> offsets :
306 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
307 SmallVector<Value> slicesOperands(contractOp.getNumOperands());
// Helper: slice one operand at its permuted offsets/shape.
310 auto extractOperand = [&](
unsigned index, Value operand,
311 AffineMap permutationMap,
312 ArrayRef<int64_t> operandOffets) {
314 permutationMap, ArrayRef<int64_t>(*targetShape));
315 SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
316 slicesOperands[index] =
318 loc, operand, operandOffets, operandShape, operandStrides);
322 AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
323 SmallVector<int64_t> lhsOffets =
325 extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);
328 AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
329 SmallVector<int64_t> rhsOffets =
331 extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);
333 AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
334 SmallVector<int64_t> accOffets =
// Reuse the running accumulator for this dst tile if one exists.
338 auto *accIt = accCache.find(accOffets);
339 if (accIt != accCache.end())
340 slicesOperands[2] = accIt->second;
342 extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);
344 SmallVector<int64_t> dstShape =
346 auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
348 rewriter, loc, contractOp, slicesOperands, targetType);
350 SmallVector<int64_t> dstOffets =
// The new partial result becomes the accumulator for this tile.
354 accCache[dstOffets] = newOp->
getResult(0);
// Assemble the full result from the cached partial tiles.
357 Value
result = arith::ConstantOp::create(rewriter, loc, dstVecType,
359 for (
const auto &it : accCache) {
360 SmallVector<int64_t> dstStrides(it.first.size(), 1);
362 loc, it.second,
result, it.first, dstStrides);
369 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.multi_reduction. Two cases: (1) scalar result
// (full reduction) — chain each tile's reduction through a running
// accumulator; (2) vector result — keep per-destination-offset partial
// accumulators in accCache (non-reduced dims only) and assemble the final
// vector from the cache.
// NOTE(review): interior lines (failure guards, accumulator fallback on
// original lines 442-443, final replaceOp for the vector case) are missing
// from this extract — code kept byte-identical.
372struct UnrollMultiReductionPattern
374 UnrollMultiReductionPattern(MLIRContext *context,
375 const vector::UnrollVectorOptions &options,
376 PatternBenefit benefit = 1)
377 : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
380 LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
381 PatternRewriter &rewriter)
const override {
382 std::optional<SmallVector<int64_t>> targetShape =
386 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
387 Location loc = reductionOp.getLoc();
388 auto resultType = reductionOp->getResult(0).getType();
// Case 1: reduction to a scalar — serially fold tiles into one value.
393 if (resultType.isIntOrFloat()) {
394 Value accumulator = reductionOp.getAcc();
395 for (SmallVector<int64_t> offsets :
396 StaticTileOffsetRange(originalSize, *targetShape)) {
397 SmallVector<int64_t> operandStrides(offsets.size(), 1);
398 Value slicedOperand =
400 loc, reductionOp.getSource(), offsets, *targetShape,
403 rewriter, loc, reductionOp, {slicedOperand, accumulator},
407 rewriter.
replaceOp(reductionOp, accumulator);
// Case 2: vector result — cache partials per destination offset.
413 SmallVector<int64_t>, Value,
414 llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
419 for (SmallVector<int64_t> offsets :
420 StaticTileOffsetRange(originalSize, *targetShape)) {
421 SmallVector<Value> operands;
422 SmallVector<int64_t> operandStrides(offsets.size(), 1);
423 Value slicedOperand =
425 loc, reductionOp.getSource(), offsets, *targetShape,
427 operands.push_back(slicedOperand);
428 SmallVector<int64_t> dstShape;
429 SmallVector<int64_t> destOffset;
// Destination keeps only the non-reduced dimensions.
430 for (
size_t i : llvm::seq(
size_t(0), targetShape->size())) {
431 if (!reductionOp.isReducedDim(i)) {
432 destOffset.push_back(offsets[i]);
433 dstShape.push_back((*targetShape)[i]);
437 SmallVector<int64_t> accStrides(destOffset.size(), 1);
// Reuse the cached partial accumulator for this destination tile.
440 auto *accIt = accCache.find(destOffset);
441 if (accIt != accCache.end())
444 acc = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
445 loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
446 operands.push_back(acc);
447 auto targetType = VectorType::get(
448 dstShape, reductionOp.getSourceVectorType().getElementType());
450 operands, targetType);
452 accCache[destOffset] =
result;
// Assemble the final vector result from the cached partials.
455 Value
result = arith::ConstantOp::create(
456 rewriter, loc, reductionOp.getDestType(),
458 for (
const auto &it : accCache) {
459 SmallVector<int64_t> dstStrides(it.first.size(), 1);
461 loc, it.second,
result, it.first, dstStrides);
468 vector::UnrollVectorOptions options;
// Pattern that unrolls any elementwise op implementing the unroll interface
// (matches any op via MatchAnyOpTypeTag). Pads the target shape with leading
// 1s to the original rank ("adjustedTargetShape"), extracts each operand
// tile, shape-casts away the padded unit dims when needed, clones the op on
// the tile, and inserts each tile result into a constant destination.
// NOTE(review): the struct header, elementwise/trait guards, dstVecType
// declaration, and final replaceOp are missing from this extract — code kept
// byte-identical.
472 UnrollElementwisePattern(MLIRContext *context,
473 const vector::UnrollVectorOptions &options,
474 PatternBenefit benefit = 1)
475 : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
478 LogicalResult matchAndRewrite(Operation *op,
479 PatternRewriter &rewriter)
const override {
485 int64_t targetShapeRank = targetShape->size();
487 SmallVector<int64_t> originalSize =
488 *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
489 int64_t originalShapeRank = originalSize.size();
491 Location loc = op->
getLoc();
// Pad the target shape with leading 1s up to the original rank so the
// tile iteration space matches the op's full rank.
494 SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
495 int64_t rankDiff = originalShapeRank - targetShapeRank;
496 std::fill(adjustedTargetShape.begin(),
497 adjustedTargetShape.begin() + rankDiff, 1);
498 std::copy(targetShape->begin(), targetShape->end(),
499 adjustedTargetShape.begin() + rankDiff);
501 int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
503 Value
result = arith::ConstantOp::create(rewriter, loc, dstVecType,
505 SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
506 VectorType unrolledVecType =
507 VectorType::get(*targetShape, dstVecType.getElementType());
510 for (SmallVector<int64_t> offsets :
511 StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
512 SmallVector<Value> extractOperands;
// Scalar (non-vector) operands are passed through unsliced.
514 auto vecType = dyn_cast<VectorType>(operand.get().getType());
516 extractOperands.push_back(operand.get());
519 Value extracted = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
520 loc, operand.get(), offsets, adjustedTargetShape, strides);
// Drop the leading unit dims introduced by the rank adjustment.
523 if (adjustedTargetShapeRank > targetShapeRank) {
525 loc, VectorType::get(*targetShape, vecType.getElementType()),
528 extractOperands.push_back(extracted);
532 rewriter, loc, op, extractOperands, unrolledVecType);
534 Value computeResult = newOp->
getResult(0);
// Insert strides must match the (possibly reduced) tile rank.
537 SmallVector<int64_t> insertStrides =
538 (adjustedTargetShapeRank > targetShapeRank)
539 ? SmallVector<int64_t>(targetShapeRank, 1)
543 loc, computeResult,
result, offsets, insertStrides);
550 vector::UnrollVectorOptions options;
// Pattern that unrolls a 1-D vector.reduction: reduces each tile of the
// source vector and chains the per-tile results through `accumulator`
// (starting from nullptr), then replaces the op with the final value.
// NOTE(review): interior lines (failure guards, the accumulator-merge logic
// between original lines 578 and 591) are missing from this extract — code
// kept byte-identical.
553struct UnrollReductionPattern :
public OpRewritePattern<vector::ReductionOp> {
554 UnrollReductionPattern(MLIRContext *context,
555 const vector::UnrollVectorOptions &options,
556 PatternBenefit benefit = 1)
557 : OpRewritePattern<vector::ReductionOp>(context, benefit),
560 LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
561 PatternRewriter &rewriter)
const override {
562 std::optional<SmallVector<int64_t>> targetShape =
566 SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
569 Location loc = reductionOp.getLoc();
// Running accumulator; nullptr until the first tile is reduced.
570 Value accumulator =
nullptr;
571 for (SmallVector<int64_t> offsets :
572 StaticTileOffsetRange(originalSize, *targetShape)) {
573 SmallVector<int64_t> strides(offsets.size(), 1);
574 Value slicedOperand =
576 loc, reductionOp.getVector(), offsets, *targetShape, strides);
578 rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
591 rewriter.
replaceOp(reductionOp, accumulator);
596 const vector::UnrollVectorOptions options;
// Pattern that unrolls vector.transpose: iterates tiles of the RESULT shape,
// maps each result offset back to source coordinates via the inverse of the
// permutation, extracts the permuted source slice, transposes it, and
// inserts it at the result offsets.
// NOTE(review): interior lines (failure guards, `result` declaration, the
// permutedShape assignment paired with permutedOffsets, final replaceOp) are
// missing from this extract — code kept byte-identical.
599struct UnrollTransposePattern :
public OpRewritePattern<vector::TransposeOp> {
600 UnrollTransposePattern(MLIRContext *context,
601 const vector::UnrollVectorOptions &options,
602 PatternBenefit benefit = 1)
603 : OpRewritePattern<vector::TransposeOp>(context, benefit),
606 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
607 PatternRewriter &rewriter)
const override {
// Guard: 0-d transpose has nothing to unroll.
608 if (transposeOp.getResultVectorType().getRank() == 0)
613 auto originalVectorType = transposeOp.getResultVectorType();
614 SmallVector<int64_t> strides(targetShape->size(), 1);
615 Location loc = transposeOp.getLoc();
616 ArrayRef<int64_t> originalSize = originalVectorType.getShape();
620 arith::ConstantOp::create(rewriter, loc, originalVectorType,
622 ArrayRef<int64_t> permutation = transposeOp.getPermutation();
625 for (SmallVector<int64_t> elementOffsets :
626 StaticTileOffsetRange(originalSize, *targetShape)) {
627 SmallVector<int64_t> permutedOffsets(elementOffsets.size());
628 SmallVector<int64_t> permutedShape(elementOffsets.size());
// Scatter result-space offsets into source order via the permutation.
630 for (
auto indices : llvm::enumerate(permutation)) {
631 permutedOffsets[
indices.value()] = elementOffsets[
indices.index()];
634 Value slicedOperand =
636 loc, transposeOp.getVector(), permutedOffsets, permutedShape,
638 Value transposedSlice = rewriter.
createOrFold<vector::TransposeOp>(
639 loc, slicedOperand, permutation);
641 loc, transposedSlice,
result, elementOffsets, strides);
648 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.gather: slices the index, mask, and pass-thru
// vectors per tile, issues a smaller gather for each tile against the same
// base/offsets, and inserts the tile results into a constant accumulator.
// NOTE(review): the struct header, failure guards, `result`/`targetType`
// declarations, and final replaceOp are missing from this extract — code
// kept byte-identical.
652 UnrollGatherPattern(MLIRContext *context,
653 const vector::UnrollVectorOptions &options,
654 PatternBenefit benefit = 1)
655 : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
658 LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
659 PatternRewriter &rewriter)
const override {
660 VectorType sourceVectorType = gatherOp.getVectorType();
// Guard: 0-d gather has nothing to unroll.
661 if (sourceVectorType.getRank() == 0)
666 SmallVector<int64_t> strides(targetShape->size(), 1);
667 Location loc = gatherOp.getLoc();
668 ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();
672 arith::ConstantOp::create(rewriter, loc, sourceVectorType,
675 VectorType::get(*targetShape, sourceVectorType.getElementType());
677 SmallVector<int64_t> loopOrder =
679 for (SmallVector<int64_t> elementOffsets :
680 StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
// Slice all three per-element operands at the same tile offsets.
684 Value indexSubVec = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
685 loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
686 Value maskSubVec = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
687 loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
688 Value passThruSubVec =
690 loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
692 auto slicedGather = vector::GatherOp::create(
693 rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
694 indexSubVec, maskSubVec, passThruSubVec);
697 loc, slicedGather,
result, elementOffsets, strides);
704 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.load into per-tile loads at adjusted indices,
// inserting each tile into a constant accumulator vector.
// NOTE(review): the struct header, failure guards, the per-tile index
// computation feeding LoadOp::create, and final replaceOp are missing from
// this extract — code kept byte-identical.
708 UnrollLoadPattern(MLIRContext *context,
709 const vector::UnrollVectorOptions &options,
710 PatternBenefit benefit = 1)
711 : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
713 LogicalResult matchAndRewrite(vector::LoadOp loadOp,
714 PatternRewriter &rewriter)
const override {
715 VectorType vecType = loadOp.getVectorType();
721 Location loc = loadOp.getLoc();
722 ArrayRef<int64_t> originalShape = vecType.getShape();
723 SmallVector<int64_t> strides(targetShape->size(), 1);
725 Value
result = arith::ConstantOp::create(rewriter, loc, vecType,
728 SmallVector<int64_t> loopOrder =
732 VectorType::get(*targetShape, vecType.getElementType());
734 for (SmallVector<int64_t> offsets :
735 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
// Load one target-shape tile and accumulate it into `result`.
738 Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
741 loc, slicedLoad,
result, offsets, strides);
748 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.store: extracts each target-shape slice of the
// stored value and issues a per-tile store at adjusted indices, then erases
// the original op.
// NOTE(review): the struct header, failure guards, the `indices`
// computation, and the final eraseOp are missing from this extract — code
// kept byte-identical.
752 UnrollStorePattern(MLIRContext *context,
753 const vector::UnrollVectorOptions &options,
754 PatternBenefit benefit = 1)
755 : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
757 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
758 PatternRewriter &rewriter)
const override {
759 VectorType vecType = storeOp.getVectorType();
765 Location loc = storeOp.getLoc();
766 ArrayRef<int64_t> originalShape = vecType.getShape();
767 SmallVector<int64_t> strides(targetShape->size(), 1);
769 Value base = storeOp.getBase();
770 Value vector = storeOp.getValueToStore();
772 SmallVector<int64_t> loopOrder =
775 for (SmallVector<int64_t> offsets :
776 StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
// Extract the tile to store for this iteration.
779 Value slice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
780 loc, vector, offsets, *targetShape, strides);
781 vector::StoreOp::create(rewriter, loc, slice, base,
indices);
788 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.broadcast: for each result tile it either
// reuses the scalar/unchanged source directly or extracts the matching
// trailing-rank slice of the vector source (treating source unit dims as
// broadcast — offset forced to 0), broadcasts it to the tile type, and
// inserts it into the result.
// NOTE(review): interior lines (failure guards, `result`/`newSrc` handling,
// the unit-dim offset zeroing inside the rank loop on original lines
// 831-833, the per-tile broadcast + insert, final replaceOp) are missing
// from this extract — code kept byte-identical.
791struct UnrollBroadcastPattern :
public OpRewritePattern<vector::BroadcastOp> {
792 UnrollBroadcastPattern(MLIRContext *context,
793 const vector::UnrollVectorOptions &options,
794 PatternBenefit benefit = 1)
795 : OpRewritePattern<vector::BroadcastOp>(context, benefit),
798 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
799 PatternRewriter &rewriter)
const override {
804 Location loc = broadcastOp.getLoc();
// srcType is null when the broadcast source is a scalar.
805 VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
806 VectorType resType = broadcastOp.getResultVectorType();
807 VectorType targetType =
808 resType.cloneWith(*targetShape, resType.getElementType());
809 Value
result = arith::ConstantOp::create(rewriter, loc, resType,
812 SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
813 SmallVector<int64_t> strides(originalShape.size(), 1);
815 for (SmallVector<int64_t> offsets :
816 StaticTileOffsetRange(originalShape, *targetShape)) {
// Scalar source: broadcast it directly, no slicing needed.
820 newSrc = broadcastOp.getSource();
// Vector source: slice the trailing `rank` dims of the tile coords.
823 int64_t rank = srcType.getRank();
824 SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
825 SmallVector<int64_t> srcShape(targetShape->end() - rank,
827 SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
// Unit source dims are broadcast dims within the slice.
829 for (int64_t i = 0; i < rank; ++i) {
830 if (srcType.getDimSize(i) == 1) {
835 newSrc = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
836 loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
851 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.to_elements: splits the source vector into
// sub-vectors (via a helper whose call is missing from this extract), emits
// a to_elements per sub-vector, and concatenates all element results.
// NOTE(review): interior lines (the helper producing `result`, the failure
// check, `subElements` declaration, final replaceOp) are missing from this
// extract — code kept byte-identical.
871struct UnrollToElements final :
public OpRewritePattern<vector::ToElementsOp> {
872 UnrollToElements(MLIRContext *context,
873 const vector::UnrollVectorOptions &options,
874 PatternBenefit benefit = 1)
875 : OpRewritePattern<vector::ToElementsOp>(context, benefit),
878 LogicalResult matchAndRewrite(vector::ToElementsOp op,
879 PatternRewriter &rewriter)
const override {
882 FailureOr<SmallVector<Value>>
result =
887 SmallVector<Value> vectors = *
result;
// Flatten each sub-vector's elements into the combined result list.
889 SmallVector<Value> results;
890 for (Value vector : vectors) {
892 vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
893 llvm::append_range(results, subElements.getResults());
900 vector::UnrollVectorOptions options;
// Pattern that unrolls a 1-D vector.step: emits one base step of the target
// size, then for each tile adds a splat of the tile's starting offset and
// inserts the shifted tile into the result. Bails on scalable vectors.
// NOTE(review): the struct header, failure guards, the splat/dense attr
// construction around original lines 961-962, and final replaceOp are
// missing from this extract — code kept byte-identical.
930 UnrollStepPattern(MLIRContext *context,
931 const vector::UnrollVectorOptions &options,
932 PatternBenefit benefit = 1)
933 : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}
935 LogicalResult matchAndRewrite(vector::StepOp stepOp,
936 PatternRewriter &rewriter)
const override {
937 std::optional<SmallVector<int64_t>> targetShape =
942 VectorType vecType = stepOp.getType();
// Guard: scalable step vectors cannot be statically tiled.
943 if (vecType.isScalable()) {
947 int64_t originalSize = vecType.getShape()[0];
948 Location loc = stepOp.getLoc();
949 SmallVector<int64_t> strides(1, 1);
951 Value
result = arith::ConstantOp::create(rewriter, loc, vecType,
955 VectorType::get(*targetShape, vecType.getElementType());
// One small step vector, reused by every tile via offset addition.
956 Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
957 for (
const SmallVector<int64_t> &offsets :
958 StaticTileOffsetRange({originalSize}, *targetShape)) {
959 Value bcastOffset = arith::ConstantOp::create(
960 rewriter, loc, targetVecType,
963 IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
965 arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
968 loc, tileStep,
result, offsets, strides);
975 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.from_elements: a callback builds each
// sub-vector from the corresponding contiguous slice of the full element
// list (index * subTyNumElements onward).
// NOTE(review): the struct header, `allElements` definition, the assert
// message text, and the driver that invokes unrollFromElementsFn are missing
// from this extract — code kept byte-identical.
996 UnrollFromElements(MLIRContext *context,
997 const vector::UnrollVectorOptions &options,
998 PatternBenefit benefit = 1)
999 : OpRewritePattern<vector::FromElementsOp>(context, benefit),
1002 LogicalResult matchAndRewrite(vector::FromElementsOp op,
1003 PatternRewriter &rewriter)
const override {
// Build the index-th sub-vector from its slice of the element list.
1006 auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
1007 VectorType subTy, int64_t index) {
1008 size_t subTyNumElements = subTy.getNumElements();
1009 assert((index + 1) * subTyNumElements <= allElements.size() &&
1012 allElements.slice(index * subTyNumElements, subTyNumElements);
1013 return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
1020 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.create_mask: per tile, each dynamic mask bound
// is rebased by the tile offset (subi), clamped below at 0 (maxsi) and above
// at the tile size (minsi), and a smaller create_mask is inserted into the
// result.
// NOTE(review): interior lines (failure guards, `offsetVal`/`zero`/
// `nonNegative` definitions, final replaceOp) are missing from this extract
// — code kept byte-identical.
1056struct UnrollCreateMaskPattern :
public OpRewritePattern<vector::CreateMaskOp> {
1057 UnrollCreateMaskPattern(MLIRContext *context,
1058 const vector::UnrollVectorOptions &options,
1059 PatternBenefit benefit = 1)
1060 : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1063 LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
1064 PatternRewriter &rewriter)
const override {
1069 VectorType resultType = createMaskOp.getVectorType();
1070 SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
1071 Location loc = createMaskOp.getLoc();
1073 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1075 VectorType targetVectorType =
1076 VectorType::get(*targetShape, rewriter.
getI1Type());
1077 SmallVector<int64_t> strides(targetShape->size(), 1);
1081 for (SmallVector<int64_t> offsets :
1082 StaticTileOffsetRange(originalSize, *targetShape)) {
1083 SmallVector<Value> unrolledOperands;
// Rebase and clamp each mask bound into this tile's local range.
1085 for (
auto [i, originalMaskOperand] :
1086 llvm::enumerate(createMaskOp.getOperands())) {
1089 Value adjustedMaskSize = rewriter.
createOrFold<arith::SubIOp>(
1090 loc, originalMaskOperand, offsetVal);
1092 Value unrolledDimSize =
// Clamp: max(bound - offset, 0), then min with the tile extent.
1095 rewriter.
createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1096 Value unrolledOperand = rewriter.
createOrFold<arith::MinSIOp>(
1097 loc, nonNegative, unrolledDimSize);
1098 unrolledOperands.push_back(unrolledOperand);
1101 auto unrolledMask = rewriter.
createOrFold<vector::CreateMaskOp>(
1102 loc, targetVectorType, unrolledOperands);
1104 loc, unrolledMask,
result, offsets, strides);
1111 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.constant_mask — the static analogue of
// UnrollCreateMaskPattern: per tile each static mask dim is rebased by the
// tile offset and clamped with std::max/std::min at compile time, producing
// a smaller constant_mask that is inserted into the result.
// NOTE(review): interior lines (failure guards, final replaceOp) are missing
// from this extract — code kept byte-identical.
1146struct UnrollConstantMaskPattern
1148 UnrollConstantMaskPattern(MLIRContext *context,
1149 const vector::UnrollVectorOptions &options,
1150 PatternBenefit benefit = 1)
1151 : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
1154 LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
1155 PatternRewriter &rewriter)
const override {
1156 std::optional<SmallVector<int64_t>> targetShape =
1161 VectorType resultType = constantMaskOp.getVectorType();
1162 SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
1163 Location loc = constantMaskOp.getLoc();
1165 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1167 VectorType targetVectorType =
1168 VectorType::get(*targetShape, rewriter.
getI1Type());
1169 SmallVector<int64_t> strides(targetShape->size(), 1);
1173 for (
const SmallVector<int64_t> &offsets :
1174 StaticTileOffsetRange(originalSize, *targetShape)) {
1175 SmallVector<int64_t> unrolledMaskDims;
// Statically rebase and clamp each mask dim into the tile's range.
1177 for (
auto [i, originalMaskDim] :
1178 llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
1181 int64_t adjustedMaskSize =
1182 std::max(originalMaskDim - offsets[i],
static_cast<int64_t
>(0));
1183 int64_t unrolledMaskDim =
1184 std::min(adjustedMaskSize,
static_cast<int64_t
>((*targetShape)[i]));
1185 unrolledMaskDims.push_back(unrolledMaskDim);
1188 auto unrolledMask = rewriter.
createOrFold<vector::ConstantMaskOp>(
1189 loc, targetVectorType, unrolledMaskDims);
1191 loc, unrolledMask,
result, offsets, strides);
1198 vector::UnrollVectorOptions options;
// NOTE(review): fragment of a contiguity check — the enclosing signature is
// missing from this extract. As visible: rejects empty or over-rank
// extractShape, strips leading unit dims from both shapes, then requires all
// but the leading extract dim to match the trailing dims of `shape` exactly
// and the total element counts to divide evenly.
1216 if (extractShape.empty() ||
shape.empty() ||
1217 extractShape.size() >
shape.size())
// Leading unit dims do not affect contiguity; drop them.
1220 while (extractShape.size() > 1 && extractShape.front() == 1)
1221 extractShape = extractShape.drop_front();
1223 while (
shape.size() > 1 &&
shape.front() == 1) {
1227 size_t rankDiff =
shape.size() - extractShape.size();
// All but the outermost extract dim must match the reference shape.
1228 if (!llvm::equal(extractShape.drop_front(),
shape.drop_front(rankDiff + 1)))
1231 int64_t extractElements = ShapedType::getNumElements(extractShape);
1232 int64_t shapeElements = ShapedType::getNumElements(
shape);
1233 return shapeElements % extractElements == 0;
// Computes, walking source dims from innermost outwards, the largest
// per-dim extract shape that covers exactly `targetElements` contiguous
// elements; pads with leading 1s to the source rank. Returns std::nullopt
// when the element count cannot be factored along the source dims.
// NOTE(review): the parameter list and the `extractShape` declaration
// (original lines 1256-1258) are missing from this extract — code kept
// byte-identical.
1255static std::optional<SmallVector<int64_t>>
1259 int64_t remainingElements = targetElements;
// Greedily take as much of each trailing dim as still needed.
1262 for (
int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
1263 int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
1264 extractShape.insert(extractShape.begin(), takeFromDim);
// A partial dim that does not divide the remainder is non-contiguous.
1266 if (remainingElements % takeFromDim != 0)
1267 return std::nullopt;
1268 remainingElements /= takeFromDim;
// Pad with leading unit dims up to the source rank.
1272 while (extractShape.size() < sourceShape.size())
1273 extractShape.insert(extractShape.begin(), 1);
1275 if (ShapedType::getNumElements(extractShape) != targetElements)
1276 return std::nullopt;
1278 return extractShape;
// Pattern that unrolls vector.shape_cast: requires the target shape to be
// contiguous within the result shape (isContiguous) and a contiguous source
// extract shape of matching element count (calculateSourceExtractShape);
// then, per result tile, extracts the corresponding source chunk, shape-casts
// it to the target tile type, and inserts it into the result.
// NOTE(review): interior lines (failure guards, extract strides, final
// replaceOp) are missing from this extract — code kept byte-identical.
1323struct UnrollShapeCastPattern :
public OpRewritePattern<vector::ShapeCastOp> {
1324 UnrollShapeCastPattern(MLIRContext *context,
1325 const vector::UnrollVectorOptions &options,
1326 PatternBenefit benefit = 1)
1327 : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1330 LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
1331 PatternRewriter &rewriter)
const override {
1332 std::optional<SmallVector<int64_t>> targetShape =
1337 VectorType sourceType = shapeCastOp.getSourceVectorType();
1338 VectorType resultType = shapeCastOp.getResultVectorType();
1339 ArrayRef<int64_t> sourceShape = sourceType.getShape();
1340 ArrayRef<int64_t> resultShape = resultType.getShape();
// Only handle tiles that are contiguous runs of the result vector.
1342 if (!isContiguous(*targetShape, resultShape))
1344 shapeCastOp,
"Only supports cases where target shape is "
1345 "contiguous in result vector shape");
1347 int64_t targetElements = ShapedType::getNumElements(*targetShape);
1350 std::optional<SmallVector<int64_t>> extractShape =
1351 calculateSourceExtractShape(sourceShape, targetElements);
1355 "cannot extract target number of elements contiguously from source");
1357 Location loc = shapeCastOp.getLoc();
1360 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1363 VectorType targetType =
1364 VectorType::get(*targetShape, sourceType.getElementType());
1367 SmallVector<int64_t> insertStrides(targetShape->size(), 1);
1369 for (SmallVector<int64_t> resultOffsets :
1370 StaticTileOffsetRange(resultShape, *targetShape)) {
// Map the result tile's offsets back to source coordinates.
1371 SmallVector<int64_t> sourceOffsets =
1372 calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
1373 Value sourceChunk = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1374 loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
1376 Value targetChunk = rewriter.
createOrFold<vector::ShapeCastOp>(
1377 loc, targetType, sourceChunk);
1379 loc, targetChunk,
result, resultOffsets, insertStrides);
1387 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.bitcast. The bitcast changes only the
// innermost dim's extent by the ratio of element bit widths, so the source
// slice shape/offsets equal the result ones except for the last dim, which
// is rescaled by resultElementBits / sourceElementBits.
// NOTE(review): the struct header, failure guards, the extract strides, and
// final replaceOp are missing from this extract — code kept byte-identical.
1405 UnrollBitCastPattern(MLIRContext *context,
1406 const vector::UnrollVectorOptions &options,
1407 PatternBenefit benefit = 1)
1408 : OpRewritePattern<vector::BitCastOp>(context, benefit),
1411 LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
1412 PatternRewriter &rewriter)
const override {
1416 "failed to get target shape");
1418 VectorType sourceType = bitCastOp.getSourceVectorType();
1419 VectorType resultType = bitCastOp.getResultVectorType();
1420 ArrayRef<int64_t> resultShape = resultType.getShape();
1421 Location loc = bitCastOp.getLoc();
1423 if (targetShape->size() != resultShape.size())
1425 bitCastOp,
"target shape rank must match result rank");
1427 unsigned sourceElementBits = sourceType.getElementTypeBitWidth();
1428 unsigned resultElementBits = resultType.getElementTypeBitWidth();
1430 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1431 targetShape->end());
1432 int64_t lastDim = sourceSliceShape.size() - 1;
// Rescale the innermost dim by the element-width ratio.
1434 sourceSliceShape[lastDim] =
1435 ((*targetShape)[lastDim] * resultElementBits) / sourceElementBits;
1437 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1439 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1440 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
1442 VectorType targetType =
1443 VectorType::get(*targetShape, resultType.getElementType());
1445 for (SmallVector<int64_t> resultOffsets :
1446 StaticTileOffsetRange(resultShape, *targetShape)) {
// Only the innermost offset differs between source and result.
1447 SmallVector<int64_t> sourceOffsets = resultOffsets;
1448 sourceOffsets[lastDim] =
1449 (resultOffsets[lastDim] * resultElementBits) / sourceElementBits;
1451 Value sourceSlice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1452 loc, bitCastOp.getSource(), sourceOffsets, sourceSliceShape,
1454 Value bitcastSlice = rewriter.
createOrFold<vector::BitCastOp>(
1455 loc, targetType, sourceSlice);
1457 loc, bitcastSlice,
result, resultOffsets, resultStrides);
1465 vector::UnrollVectorOptions options;
// Pattern that unrolls vector.interleave. Each result tile's innermost dim
// draws from half as many elements of each input (lhs/rhs), so the source
// slice shape and innermost offset are the result ones divided by 2; each
// pair of slices is interleaved and inserted into the result.
// NOTE(review): interior lines (failure guards, extract strides, final
// replaceOp) are missing from this extract — code kept byte-identical.
1485struct UnrollInterleavePattern :
public OpRewritePattern<vector::InterleaveOp> {
1486 UnrollInterleavePattern(MLIRContext *context,
1487 const vector::UnrollVectorOptions &options,
1488 PatternBenefit benefit = 1)
1489 : OpRewritePattern<vector::InterleaveOp>(context, benefit),
1492 LogicalResult matchAndRewrite(vector::InterleaveOp interleaveOp,
1493 PatternRewriter &rewriter)
const override {
1497 "failed to get target shape");
1499 VectorType resultType = interleaveOp.getResultVectorType();
1500 ArrayRef<int64_t> resultShape = resultType.getShape();
1501 Location loc = interleaveOp.getLoc();
1503 if (targetShape->size() != resultShape.size())
1505 interleaveOp,
"target shape rank must match result rank");
1507 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1508 targetShape->end());
1509 int64_t lastDim = sourceSliceShape.size() - 1;
// Each input contributes half the interleaved innermost dim.
1510 sourceSliceShape[lastDim] = (*targetShape)[lastDim] / 2;
1512 Value
result = arith::ConstantOp::create(rewriter, loc, resultType,
1514 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1515 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
1517 VectorType targetType =
1518 VectorType::get(*targetShape, resultType.getElementType());
1520 for (SmallVector<int64_t> resultOffsets :
1521 StaticTileOffsetRange(resultShape, *targetShape)) {
1522 SmallVector<int64_t> sourceOffsets = resultOffsets;
1523 sourceOffsets[lastDim] = resultOffsets[lastDim] / 2;
// Slice both inputs at the same (halved) offsets.
1525 Value lhsSlice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1526 loc, interleaveOp.getLhs(), sourceOffsets, sourceSliceShape,
1528 Value rhsSlice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1529 loc, interleaveOp.getRhs(), sourceOffsets, sourceSliceShape,
1531 Value interleaveSlice = rewriter.
createOrFold<vector::InterleaveOp>(
1532 loc, targetType, lhsSlice, rhsSlice);
1534 loc, interleaveSlice,
result, resultOffsets, resultStrides);
1542 vector::UnrollVectorOptions options;
1563struct UnrollDeinterleavePattern
1565 UnrollDeinterleavePattern(MLIRContext *context,
1566 const vector::UnrollVectorOptions &options,
1567 PatternBenefit benefit = 1)
1568 : OpRewritePattern<vector::DeinterleaveOp>(context, benefit),
1571 LogicalResult matchAndRewrite(vector::DeinterleaveOp deinterleaveOp,
1572 PatternRewriter &rewriter)
const override {
1576 "failed to get target shape");
1578 VectorType resultType = deinterleaveOp.getResultVectorType();
1579 ArrayRef<int64_t> resultShape = resultType.getShape();
1580 Location loc = deinterleaveOp.getLoc();
1582 if (targetShape->size() != resultShape.size())
1584 deinterleaveOp,
"target shape rank must match result rank");
1586 SmallVector<int64_t> sourceSliceShape(targetShape->begin(),
1587 targetShape->end());
1588 int64_t lastDim = sourceSliceShape.size() - 1;
1589 sourceSliceShape[lastDim] = (*targetShape)[lastDim] * 2;
1591 Value resultOdd = arith::ConstantOp::create(
1592 rewriter, loc, resultType, rewriter.
getZeroAttr(resultType));
1593 Value resultEven = arith::ConstantOp::create(
1594 rewriter, loc, resultType, rewriter.
getZeroAttr(resultType));
1595 SmallVector<int64_t> resultStrides(targetShape->size(), 1);
1596 SmallVector<int64_t> sourceStrides(sourceSliceShape.size(), 1);
1598 for (SmallVector<int64_t> resultOffsets :
1599 StaticTileOffsetRange(resultShape, *targetShape)) {
1600 SmallVector<int64_t> sourceOffsets = resultOffsets;
1601 sourceOffsets[lastDim] = resultOffsets[lastDim] * 2;
1603 Value sourceSlice = rewriter.
createOrFold<vector::ExtractStridedSliceOp>(
1604 loc, deinterleaveOp.getSource(), sourceOffsets, sourceSliceShape,
1607 auto deinterleaveSlice =
1608 vector::DeinterleaveOp::create(rewriter, loc, sourceSlice);
1610 resultOdd = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
1611 loc, deinterleaveSlice.getRes1(), resultOdd, resultOffsets,
1613 resultEven = rewriter.
createOrFold<vector::InsertStridedSliceOp>(
1614 loc, deinterleaveSlice.getRes2(), resultEven, resultOffsets,
1623 vector::UnrollVectorOptions options;
1628void mlir::vector::populateVectorUnrollPatterns(
1631 patterns.
add<UnrollTransferReadPattern, UnrollTransferWritePattern,
1632 UnrollContractionPattern, UnrollElementwisePattern,
1633 UnrollReductionPattern, UnrollMultiReductionPattern,
1634 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
1635 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1636 UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1637 UnrollCreateMaskPattern, UnrollConstantMaskPattern,
1638 UnrollBitCastPattern, UnrollInterleavePattern,
1643void mlir::vector::populateVectorToElementsUnrollPatterns(
1649void mlir::vector::populateVectorFromElementsUnrollPatterns(
static LogicalResult extractStrides(AffineExpr e, AffineExpr multiplicativeFactor, MutableArrayRef< AffineExpr > strides, AffineExpr &offset)
Takes a single AffineExpr e and populates the strides array with the strides expressions for each dim...
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< Value > sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc, OperandRange originalIndices, ArrayRef< int64_t > offsets)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
static SmallVector< int64_t > getUnrollOrder(unsigned numLoops, Operation *op, const vector::UnrollVectorOptions &options)
static Operation * cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, Operation *op, ArrayRef< Value > operands, ArrayRef< Type > resultTypes)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class implements the operand iterators for the Operation class.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< OpOperand > getOpOperands()
OperationName getName()
The name of an operation is the key identifier for it.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K (the most benefit possible).
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePattern is the common base class for all DAG to DAG replacements.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
FailureOr< SmallVector< Value > > unrollVectorValue(TypedValue< VectorType >, RewriterBase &)
Generic utility for unrolling values of type vector<NxAxBx...> to N values of type vector<AxBx...>.
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.