18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/InterleavedRange.h"
24#define DEBUG_TYPE "vector-unroll"
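/// Compute the indices of the slice `index` for a transfer op.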
static SmallVector<Value>
sliceTransferIndices(ArrayRef<int64_t> elementOffsets, ArrayRef<Value> indices,
                     AffineMap permutationMap, Location loc,
                     OpBuilder &builder) {
  auto isBroadcast = [](AffineExpr expr) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
      return constExpr.getValue() == 0;
    return false;
  };
  SmallVector<Value> slicedIndices(indices.begin(), indices.end());
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    auto expr = getAffineDimExpr(0, builder.getContext()) +
                getAffineConstantExpr(elementOffsets[dim.index()],
                                      builder.getContext());
    auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
    slicedIndices[pos] =
        affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
  }
  return slicedIndices;
}
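/// Compute the sliced indices for a load/store op: each unrolled offset is
/// added to the corresponding trailing original index.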
static SmallVector<Value>
sliceLoadStoreIndices(PatternRewriter &rewriter, Location loc,
                      OperandRange originalIndices, ArrayRef<int64_t> offsets) {
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> indices(originalIndices.begin(), originalIndices.end());
  auto start = indices.size() - offsets.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    if (offset != 0)
      indices[start + i] = arith::AddIOp::create(
          rewriter, loc, originalIndices[start + i],
          arith::ConstantIndexOp::create(rewriter, loc, offset));
  }
  return indices;
}
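/// Return the target shape for unrolling for the given op, or std::nullopt if
/// the op should not (or cannot) be unrolled.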
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG() << "--no filter constraint -> BAIL";
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native shape "
         "callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG() << "--not an unrollable op -> BAIL";
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG() << "--could not get shape of op " << *op << " -> BAIL";
    return std::nullopt;
  }
  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG() << "--no unrolling target shape defined " << *op << " -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);

  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG() << "--could not compute integral shape ratio -> BAIL";
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG() << "--no unrolling needed -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
  return targetShape;
}
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order)
      loopOrder = std::move(*order);
  }
  return loopOrder;
}
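/// Unrolls vector.transfer_read into reads of the target shape and
/// reassembles the result with vector.insert_strided_slice.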
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
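/// Unrolls vector.transfer_write by extracting strided slices of the source
/// vector and writing each slice; for tensor destinations the result is
/// threaded through the unrolled writes.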
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (writeOp.getTransferRank() == 0)
      return failure();
    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");

    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
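/// DenseMap info for SmallVector<int64_t> offset keys, used to cache unrolled
/// accumulator slices per destination offset.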
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
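/// Unrolls vector.contract over its iteration space; partial accumulators are
/// cached per destination offset and reassembled into the result at the end.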
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffets, operandShape, operandStrides);
      };

      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      accCache[dstOffets] = newOp->getResult(0);
    }
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
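/// Unrolls vector.multi_reduction; accumulator slices are keyed by the
/// non-reduced offsets and merged back into the destination vector.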
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
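/// Unrolls elementwise-mappable ops. When the target shape has lower rank than
/// the original vector, leading unit dims are added for slicing and removed
/// again with vector.shape_cast.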
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    int64_t targetShapeRank = targetShape->size();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    int64_t originalShapeRank = originalSize.size();

    Location loc = op->getLoc();

    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
    int64_t rankDiff = originalShapeRank - targetShapeRank;
    std::fill(adjustedTargetShape.begin(),
              adjustedTargetShape.begin() + rankDiff, 1);
    std::copy(targetShape->begin(), targetShape->end(),
              adjustedTargetShape.begin() + rankDiff);

    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
    VectorType unrolledVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand.get(), offsets, adjustedTargetShape, strides);
        if (adjustedTargetShapeRank > targetShapeRank) {
          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
              loc, VectorType::get(*targetShape, vecType.getElementType()),
              extracted);
        }
        extractOperands.push_back(extracted);
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, unrolledVecType);

      Value computeResult = newOp->getResult(0);

      SmallVector<int64_t> insertStrides =
          (adjustedTargetShapeRank > targetShapeRank)
              ? SmallVector<int64_t>(targetShapeRank, 1)
              : strides;

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, computeResult, result, offsets, insertStrides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
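/// Unrolls vector.reduction by reducing each slice and combining the partial
/// results with the matching arith reduction.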
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);
      if (!accumulator)
        accumulator = result;
      else
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
    }
    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};
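/// Unrolls vector.transpose by transposing slices extracted at permuted
/// offsets and shapes.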
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
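/// Unrolls vector.gather by slicing the index, mask, and pass-through vectors
/// with the same offsets as the result.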
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
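/// Unrolls vector.load into loads of the target shape at adjusted indices.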
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  UnrollLoadPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();
    auto targetShape = getTargetShape(options, loadOp);
    if (!targetShape)
      return failure();
    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), loadOp, options);
    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
                                                loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
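/// Unrolls vector.store into stores of strided slices at adjusted indices.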
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  UnrollStorePattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();
    auto targetShape = getTargetShape(options, storeOp);
    if (!targetShape)
      return failure();
    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), storeOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      vector::StoreOp::create(rewriter, loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
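/// Unrolls vector.broadcast. For vector sources, broadcast dimensions keep
/// offset 0 and size 1 when extracting the source slice.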
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();
    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        newSrc = broadcastOp.getSource();
      } else {
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
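/// Unrolls vector.to_elements by first unrolling the source vector and then
/// expanding each sub-vector to elements.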
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  UnrollToElements(MLIRContext *context,
                   const vector::UnrollVectorOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<SmallVector<Value>> result =
        vector::unrollVectorValue(op.getSource(), rewriter);
    if (failed(result))
      return failure();
    SmallVector<Value> vectors = *result;

    SmallVector<Value> results;
    for (Value vector : vectors) {
      auto subElements =
          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
      llvm::append_range(results, subElements.getResults());
    }
    rewriter.replaceOp(op, results);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
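/// Unrolls vector.step into a base step of the target shape plus a constant
/// offset per tile.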
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
  UnrollStepPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, stepOp);
    if (!targetShape)
      return failure();
    VectorType vecType = stepOp.getType();
    if (vecType.isScalable())
      return failure();
    int64_t originalSize = vecType.getShape()[0];
    Location loc = stepOp.getLoc();
    SmallVector<int64_t> strides(1, 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));
    VectorType targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());
    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange({originalSize}, *targetShape)) {
      Value bcastOffset = arith::ConstantOp::create(
          rewriter, loc, targetVecType,
          DenseElementsAttr::get(
              targetVecType,
              IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
      Value tileStep =
          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tileStep, result, offsets, strides);
    }
    rewriter.replaceOp(stepOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
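/// Unrolls vector.from_elements by building each sub-vector from the
/// corresponding slice of the element list.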
struct UnrollFromElements : public OpRewritePattern<vector::FromElementsOp> {
  UnrollFromElements(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::FromElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange allElements = op.getElements();
    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
                                    VectorType subTy, int64_t index) {
      size_t subTyNumElements = subTy.getNumElements();
      assert((index + 1) * subTyNumElements <= allElements.size() &&
             "out of bounds");
      ValueRange subElements =
          allElements.slice(index * subTyNumElements, subTyNumElements);
      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
    };
    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
  }

private:
  vector::UnrollVectorOptions options;
};
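/// Unrolls vector.create_mask: each unrolled mask operand becomes
/// clamp(originalSize - offset, 0, targetSize), computed with arith ops.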
struct UnrollCreateMaskPattern
    : public OpRewritePattern<vector::CreateMaskOp> {
  UnrollCreateMaskPattern(MLIRContext *context,
                          const vector::UnrollVectorOptions &options,
                          PatternBenefit benefit = 1)
      : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, createMaskOp);
    if (!targetShape)
      return failure();
    VectorType resultType = createMaskOp.getVectorType();
    SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
    Location loc = createMaskOp.getLoc();

    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    VectorType targetVectorType =
        VectorType::get(*targetShape, rewriter.getI1Type());
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> unrolledOperands;
      for (auto [i, originalMaskOperand] :
           llvm::enumerate(createMaskOp.getOperands())) {
        Value offsetVal =
            arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
        Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
            loc, originalMaskOperand, offsetVal);
        Value unrolledDimSize =
            arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
        Value nonNegative =
            rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
        Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
            loc, nonNegative, unrolledDimSize);
        unrolledOperands.push_back(unrolledOperand);
      }
      auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
          loc, targetVectorType, unrolledOperands);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, unrolledMask, result, offsets, strides);
    }
    rewriter.replaceOp(createMaskOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
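/// Unrolls vector.constant_mask; the unrolled mask dims are computed
/// statically as clamp(originalDim - offset, 0, targetDim).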
struct UnrollConstantMaskPattern
    : public OpRewritePattern<vector::ConstantMaskOp> {
  UnrollConstantMaskPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ConstantMaskOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ConstantMaskOp constantMaskOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, constantMaskOp);
    if (!targetShape)
      return failure();
    VectorType resultType = constantMaskOp.getVectorType();
    SmallVector<int64_t> originalSize = *constantMaskOp.getShapeForUnroll();
    Location loc = constantMaskOp.getLoc();

    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    VectorType targetVectorType =
        VectorType::get(*targetShape, rewriter.getI1Type());
    SmallVector<int64_t> strides(targetShape->size(), 1);

    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> unrolledMaskDims;
      for (auto [i, originalMaskDim] :
           llvm::enumerate(constantMaskOp.getMaskDimSizes())) {
        int64_t adjustedMaskSize =
            std::max(originalMaskDim - offsets[i], static_cast<int64_t>(0));
        int64_t unrolledMaskDim =
            std::min(adjustedMaskSize, static_cast<int64_t>((*targetShape)[i]));
        unrolledMaskDims.push_back(unrolledMaskDim);
      }
      auto unrolledMask = rewriter.createOrFold<vector::ConstantMaskOp>(
          loc, targetVectorType, unrolledMaskDims);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, unrolledMask, result, offsets, strides);
    }
    rewriter.replaceOp(constantMaskOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
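/// Returns true if `extractShape` (ignoring leading unit dims) occupies a
/// contiguous block of the trailing dims of `shape` and its element count
/// divides the element count of `shape`.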
static bool isContiguous(ArrayRef<int64_t> extractShape,
                         ArrayRef<int64_t> shape) {
  if (extractShape.size() > shape.size())
    return false;

  while (!extractShape.empty() && extractShape.front() == 1) {
    extractShape = extractShape.drop_front();
  }
  while (!shape.empty() && shape.front() == 1) {
    shape = shape.drop_front();
  }

  size_t rankDiff = shape.size() - extractShape.size();
  if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
    return false;

  int64_t extractElements = ShapedType::getNumElements(extractShape);
  int64_t shapeElements = ShapedType::getNumElements(shape);
  return shapeElements % extractElements == 0;
}
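/// Computes a shape covering exactly `targetElements` contiguous elements of
/// the innermost dimensions of `sourceShape`; returns std::nullopt if no such
/// shape exists.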
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
                            int64_t targetElements) {
  SmallVector<int64_t> extractShape;
  int64_t remainingElements = targetElements;

  for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
    extractShape.insert(extractShape.begin(), takeFromDim);
    if (remainingElements % takeFromDim != 0)
      return std::nullopt;
    remainingElements /= takeFromDim;
  }

  while (extractShape.size() < sourceShape.size())
    extractShape.insert(extractShape.begin(), 1);

  if (ShapedType::getNumElements(extractShape) != targetElements)
    return std::nullopt;

  return extractShape;
}
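/// Unrolls vector.shape_cast by extracting contiguous chunks from the source,
/// shape_cast-ing each chunk to the target shape, and inserting it into the
/// result vector.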
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
  UnrollShapeCastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, shapeCastOp);
    if (!targetShape)
      return failure();
    VectorType sourceType = shapeCastOp.getSourceVectorType();
    VectorType resultType = shapeCastOp.getResultVectorType();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    ArrayRef<int64_t> resultShape = resultType.getShape();

    if (!isContiguous(*targetShape, resultShape))
      return rewriter.notifyMatchFailure(
          shapeCastOp, "Only supports cases where target shape is "
                       "contiguous in result vector shape");

    int64_t targetElements = ShapedType::getNumElements(*targetShape);
    std::optional<SmallVector<int64_t>> extractShape =
        calculateSourceExtractShape(sourceShape, targetElements);
    if (!extractShape)
      return rewriter.notifyMatchFailure(
          shapeCastOp,
          "cannot extract target number of elements contiguously from source");

    Location loc = shapeCastOp.getLoc();
    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    VectorType targetType =
        VectorType::get(*targetShape, sourceType.getElementType());
    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
    SmallVector<int64_t> insertStrides(targetShape->size(), 1);

    for (SmallVector<int64_t> resultOffsets :
         StaticTileOffsetRange(resultShape, *targetShape)) {
      SmallVector<int64_t> sourceOffsets =
          calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
      Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
          extractStrides);
      Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, targetType, sourceChunk);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, targetChunk, result, resultOffsets, insertStrides);
    }
    rewriter.replaceOp(shapeCastOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
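// Example usage (illustrative sketch only; `funcOp`, the fixed 4x4 native
// shape, and the `applyPatternsGreedily` driver below are assumptions, not
// part of this file):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::UnrollVectorOptions options;
//   options.setNativeShapeFn(
//       [](Operation *op) -> std::optional<SmallVector<int64_t>> {
//         // Unroll every unrollable vector op to 4x4 tiles.
//         return SmallVector<int64_t>{4, 4};
//       });
//   vector::populateVectorUnrollPatterns(patterns, options);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     return signalPassFailure();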
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
               UnrollCreateMaskPattern, UnrollConstantMaskPattern>(
      patterns.getContext(), options, benefit);
}

void mlir::vector::populateVectorToElementsUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollToElements>(patterns.getContext(), options, benefit);
}

void mlir::vector::populateVectorFromElementsUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollFromElements>(patterns.getContext(), options, benefit);
}