#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"

#define DEBUG_TYPE "vector-unroll"
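
/// Helper check used when slicing transfer-op indices: a permutation-map
/// result equal to the constant 0 denotes a broadcast dimension, which
/// contributes no index.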
  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
    return constExpr.getValue() == 0;
  for (const auto &dim : llvm::enumerate(permutationMap.getResults())) {
    if (isBroadcast(dim.value()))
      continue;
    unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition();
    affine::AffineApplyOp::create(builder, loc, map, indices[pos]);
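
/// Compute the slice indices for ops with simple indexing such as vector.load
/// and vector.store: the constant tile offsets are added to the trailing
/// original indices.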
  assert(offsets.size() <= originalIndices.size() &&
         "Offsets should not exceed the number of original indices");
  SmallVector<Value> indices(originalIndices.begin(), originalIndices.end());

  auto start = indices.size() - offsets.size();
  for (auto [i, offset] : llvm::enumerate(offsets)) {
    indices[start + i] = arith::AddIOp::create(
        rewriter, loc, originalIndices[start + i],
        arith::ConstantIndexOp::create(rewriter, loc, offset));
  }
  return indices;
}
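
/// Return the target shape for unrolling for the given op, or std::nullopt if
/// the op should not (or cannot) be unrolled.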
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG() << "--no filter constraint -> BAIL";
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native "
         "shape callback function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG() << "--not an unrollable op -> BAIL";
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG() << "--could not get shape of op " << *op << " -> BAIL";
    return std::nullopt;
  }
  LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape);
  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG() << "--no unrolling target shape defined " << *op << " -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--target shape: " << llvm::interleaved(*targetShape);
  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG() << "--could not compute integral shape ratio -> BAIL";
    return std::nullopt;
  }
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG() << "--no unrolling needed -> SKIP";
    return std::nullopt;
  }
  LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS";
  return targetShape;
}
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order)
      loopOrder = std::move(*order);
  }
  return loopOrder;
}
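
/// Unrolls a vector.transfer_read into transfer_reads of the target shape and
/// reassembles the full vector by inserting each slice into a zero constant.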
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    if (readOp.getTransferRank() == 0)
      return failure();
    if (readOp.getMask())
      return failure();
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = vector::TransferReadOp::create(
          rewriter, loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(),
          readOp.getMask(), readOp.getInBoundsAttr());

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
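
/// Unrolls a vector.transfer_write by extracting strided slices of the written
/// vector and emitting one transfer_write per tile, threading tensor results
/// through the chain of slice writes.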
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (writeOp.getTransferRank() == 0)
      return failure();
    if (writeOp.getMask())
      return failure();
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();

    if (originalSize.size() != targetShape->size())
      return rewriter.notifyMatchFailure(
          writeOp,
          "expected source input vector rank to match target shape rank");

    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = vector::TransferWriteOp::create(
          rewriter, loc, slicedVector,
          resultTensor ? resultTensor : writeOp.getBase(), indices,
          writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

  vector::UnrollVectorOptions options;
};
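
/// DenseMap key info for SmallVector<int64_t> tile offsets, used to cache
/// accumulator slices keyed by their destination offsets.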
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};
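
/// Unrolls vector.contract into contractions on target-shape tiles, caching
/// the accumulator slice for each destination offset so partial products are
/// chained, then reassembling the result vector from the cached slices.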
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();
    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffsets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
        slicesOperands[index] =
            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
                loc, operand, operandOffsets, operandShape, operandStrides);
      };

      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffsets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffsets);

      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffsets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffsets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffsets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      auto *accIt = accCache.find(accOffsets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffsets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffsets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      accCache[dstOffsets] = newOp->getResult(0);
    }
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
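
/// Unrolls vector.multi_reduction into reductions on target-shape tiles,
/// caching accumulator slices keyed by the offsets of the non-reduced dims.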
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    auto resultType = reductionOp->getResult(0).getType();
    if (resultType.isIntOrFloat()) {
      return rewriter.notifyMatchFailure(reductionOp,
                                         "Unrolling scalars is not supported");
    }
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getSource(), offsets, *targetShape,
              operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    Value result = arith::ConstantOp::create(
        rewriter, loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
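
/// Unrolls elementwise vector ops (the pattern matches any op type) by
/// extracting target-shape slices of every vector operand, cloning the op on
/// the slices, and inserting the partial results into the full-size result.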
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op))
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    int64_t targetShapeRank = targetShape->size();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    int64_t originalShapeRank = originalSize.size();

    Location loc = op->getLoc();

    SmallVector<int64_t> adjustedTargetShape(originalShapeRank);
    int64_t rankDiff = originalShapeRank - targetShapeRank;
    std::fill(adjustedTargetShape.begin(),
              adjustedTargetShape.begin() + rankDiff, 1);
    std::copy(targetShape->begin(), targetShape->end(),
              adjustedTargetShape.begin() + rankDiff);

    int64_t adjustedTargetShapeRank = adjustedTargetShape.size();
    Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                             rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(adjustedTargetShapeRank, 1);
    VectorType unrolledVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, operand.get(), offsets, adjustedTargetShape, strides);

        if (adjustedTargetShapeRank > targetShapeRank) {
          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
              loc, VectorType::get(*targetShape, vecType.getElementType()),
              extracted);
        }
        extractOperands.push_back(extracted);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, unrolledVecType);

      Value computeResult = newOp->getResult(0);

      SmallVector<int64_t> insertStrides =
          (adjustedTargetShapeRank > targetShapeRank)
              ? SmallVector<int64_t>(targetShapeRank, 1)
              : strides;
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, computeResult, result, offsets, insertStrides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
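
/// Unrolls vector.reduction by reducing each target-shape slice separately and
/// combining the partial results into a single accumulator.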
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);
      if (!accumulator)
        accumulator = result;
      else
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

  const vector::UnrollVectorOptions options;
};
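
/// Unrolls vector.transpose by extracting permuted slices of the source,
/// transposing each slice, and inserting it at the corresponding result offset.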
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    Value result =
        arith::ConstantOp::create(rewriter, loc, originalVectorType,
                                  rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, transposeOp.getVector(), permutedOffsets, permutedShape,
              strides);
      Value transposedSlice = rewriter.createOrFold<vector::TransposeOp>(
          loc, slicedOperand, permutation);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
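
/// Unrolls vector.gather by slicing the index, mask, and pass-thru vectors and
/// emitting one gather per target-shape tile.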
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    Value result =
        arith::ConstantOp::create(rewriter, loc, sourceVectorType,
                                  rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec =
          rewriter.createOrFold<vector::ExtractStridedSliceOp>(
              loc, gatherOp.getPassThru(), elementOffsets, *targetShape,
              strides);
      auto slicedGather = vector::GatherOp::create(
          rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
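
/// Unrolls vector.load into loads of the target shape at shifted indices.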
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  UnrollLoadPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, loadOp);
    if (!targetShape)
      return failure();
    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), loadOp, options);
    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      Value slicedLoad = vector::LoadOp::create(rewriter, loc, targetVecType,
                                                loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
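
/// Unrolls vector.store into stores of target-shape slices at shifted indices.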
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  UnrollStorePattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, storeOp);
    if (!targetShape)
      return failure();
    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), storeOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      vector::StoreOp::create(rewriter, loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }

  vector::UnrollVectorOptions options;
};
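
/// Unrolls vector.broadcast: for vector sources the matching source slice is
/// extracted (broadcast dimensions keep offset 0 and size 1); scalar sources
/// are reused directly for every tile.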
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
  UnrollBroadcastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::BroadcastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, broadcastOp);
    if (!targetShape)
      return failure();

    Location loc = broadcastOp.getLoc();
    VectorType srcType = dyn_cast<VectorType>(broadcastOp.getSourceType());
    VectorType resType = broadcastOp.getResultVectorType();
    VectorType targetType =
        resType.cloneWith(*targetShape, resType.getElementType());
    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));

    SmallVector<int64_t> originalShape = *broadcastOp.getShapeForUnroll();
    SmallVector<int64_t> strides(originalShape.size(), 1);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape)) {
      Value newSrc;
      if (!srcType) {
        newSrc = broadcastOp.getSource();
      } else {
        int64_t rank = srcType.getRank();
        SmallVector<int64_t> srcOffsets(offsets.end() - rank, offsets.end());
        SmallVector<int64_t> srcShape(targetShape->end() - rank,
                                      targetShape->end());
        SmallVector<int64_t> srcStrides(strides.end() - rank, strides.end());
        for (int64_t i = 0; i < rank; ++i) {
          if (srcType.getDimSize(i) == 1) {
            srcOffsets[i] = 0;
            srcShape[i] = 1;
          }
        }
        newSrc = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
            loc, broadcastOp.getSource(), srcOffsets, srcShape, srcStrides);
      }

      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, broadcastOp,
                                                     newSrc, targetType);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(broadcastOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
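
/// Unrolls vector.to_elements by splitting the source along its leading
/// dimension and concatenating the elements produced by the per-slice ops.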
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
  UnrollToElements(MLIRContext *context,
                   const vector::UnrollVectorOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ToElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ToElementsOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<SmallVector<Value>> result =
        vector::unrollVectorValue(op.getSource(), rewriter);
    if (failed(result))
      return failure();
    SmallVector<Value> vectors = *result;

    SmallVector<Value> results;
    for (Value vector : vectors) {
      auto subElements =
          vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
      llvm::append_range(results, subElements.getResults());
    }
    rewriter.replaceOp(op, results);
    return success();
  }

  vector::UnrollVectorOptions options;
};
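
/// Unrolls vector.step by materializing one base step of the target shape and
/// adding a constant offset for each tile.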
struct UnrollStepPattern : public OpRewritePattern<vector::StepOp> {
  UnrollStepPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StepOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StepOp stepOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, stepOp);
    if (!targetShape)
      return failure();
    VectorType vecType = stepOp.getType();
    if (vecType.isScalable()) {
      return failure();
    }
    int64_t originalSize = vecType.getShape()[0];
    Location loc = stepOp.getLoc();
    SmallVector<int64_t> strides(1, 1);

    Value result = arith::ConstantOp::create(rewriter, loc, vecType,
                                             rewriter.getZeroAttr(vecType));
    VectorType targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());
    Value baseStep = vector::StepOp::create(rewriter, loc, targetVecType);
    for (const SmallVector<int64_t> &offsets :
         StaticTileOffsetRange({originalSize}, *targetShape)) {
      Value bcastOffset = arith::ConstantOp::create(
          rewriter, loc, targetVecType,
          DenseElementsAttr::get(
              targetVecType,
              IntegerAttr::get(targetVecType.getElementType(), offsets[0])));
      Value tileStep =
          arith::AddIOp::create(rewriter, loc, baseStep, bcastOffset);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, tileStep, result, offsets, strides);
    }
    rewriter.replaceOp(stepOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
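
/// Unrolls vector.from_elements along the leading dimension: each sub-vector
/// is built from the corresponding contiguous slice of the scalar operands.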
struct UnrollFromElements : public OpRewritePattern<vector::FromElementsOp> {
  UnrollFromElements(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::FromElementsOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange allElements = op.getElements();

    auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
                                    VectorType subTy, int64_t index) {
      size_t subTyNumElements = subTy.getNumElements();
      assert((index + 1) * subTyNumElements <= allElements.size() &&
             "slice out of bounds");
      ValueRange subElements =
          allElements.slice(index * subTyNumElements, subTyNumElements);
      return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
    };
    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
  }

  vector::UnrollVectorOptions options;
};
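
/// Return true if `extractShape` (ignoring leading unit dims) describes a
/// contiguous chunk of `shape`: its inner dimensions match the corresponding
/// trailing dimensions of `shape` and its element count divides that of
/// `shape`.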
static bool isContiguous(ArrayRef<int64_t> extractShape,
                         ArrayRef<int64_t> shape) {
  if (extractShape.size() > shape.size())
    return false;

  while (!extractShape.empty() && extractShape.front() == 1) {
    extractShape = extractShape.drop_front();
  }
  while (!shape.empty() && shape.front() == 1) {
    shape = shape.drop_front();
  }

  size_t rankDiff = shape.size() - extractShape.size();
  if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
    return false;

  int64_t extractElements = ShapedType::getNumElements(extractShape);
  int64_t shapeElements = ShapedType::getNumElements(shape);
  return shapeElements % extractElements == 0;
}
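
/// Greedily compute, from the innermost dimension outwards, a source shape
/// covering exactly `targetElements` contiguous elements; returns std::nullopt
/// if no such shape exists.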
static std::optional<SmallVector<int64_t>>
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
                            int64_t targetElements) {
  SmallVector<int64_t> extractShape;
  int64_t remainingElements = targetElements;

  for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
    extractShape.insert(extractShape.begin(), takeFromDim);

    if (remainingElements % takeFromDim != 0)
      return std::nullopt;
    remainingElements /= takeFromDim;
  }

  while (extractShape.size() < sourceShape.size())
    extractShape.insert(extractShape.begin(), 1);

  if (ShapedType::getNumElements(extractShape) != targetElements)
    return std::nullopt;

  return extractShape;
}
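
/// Unrolls vector.shape_cast by extracting contiguous chunks of the source,
/// shape_casting each chunk to the target shape, and inserting it at the
/// corresponding offsets of the result.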
struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
  UnrollShapeCastPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, shapeCastOp);
    if (!targetShape)
      return failure();
    VectorType sourceType = shapeCastOp.getSourceVectorType();
    VectorType resultType = shapeCastOp.getResultVectorType();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    ArrayRef<int64_t> resultShape = resultType.getShape();

    if (!isContiguous(*targetShape, resultShape))
      return rewriter.notifyMatchFailure(
          shapeCastOp, "Only supports cases where target shape is "
                       "contiguous in result vector shape");

    int64_t targetElements = ShapedType::getNumElements(*targetShape);

    std::optional<SmallVector<int64_t>> extractShape =
        calculateSourceExtractShape(sourceShape, targetElements);
    if (!extractShape)
      return rewriter.notifyMatchFailure(
          shapeCastOp,
          "cannot extract target number of elements contiguously from source");

    Location loc = shapeCastOp.getLoc();

    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
                                             rewriter.getZeroAttr(resultType));
    VectorType targetType =
        VectorType::get(*targetShape, sourceType.getElementType());

    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
    SmallVector<int64_t> insertStrides(targetShape->size(), 1);

    for (SmallVector<int64_t> resultOffsets :
         StaticTileOffsetRange(resultShape, *targetShape)) {
      SmallVector<int64_t> sourceOffsets =
          calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
      Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
          extractStrides);
      Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
          loc, targetType, sourceChunk);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, targetChunk, result, resultOffsets, insertStrides);
    }
    rewriter.replaceOp(shapeCastOp, result);
    return success();
  }

  vector::UnrollVectorOptions options;
};
void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
               UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
      patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(

void mlir::vector::populateVectorFromElementsUnrollPatterns(
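
// Example usage (illustrative sketch, not part of the upstream file): a client
// pass would typically populate these patterns with a fixed native shape and
// apply them greedily. The native shape, pass name, and surrounding
// boilerplate below are assumptions made for illustration only.
//
//   void ExampleUnrollPass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     RewritePatternSet patterns(ctx);
//     vector::UnrollVectorOptions options;
//     // Hypothetical native shape: unroll eligible ops to 2x8 tiles.
//     options.setNativeShape(SmallVector<int64_t>{2, 8});
//     vector::populateVectorUnrollPatterns(patterns, options);
//     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//       signalPassFailure();
//   }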