33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <type_traits>
40 #define DEBUG_TYPE "linalg-transforms"
45 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
46 #define DBGSNL() (llvm::dbgs() << "\n")
62 .Case<scf::ForOp>([&](scf::ForOp forOp) {
63 scf::ForOp partialIteration;
66 return partialIteration->getResults();
67 assert(!partialIteration &&
"expected that loop was not peeled");
68 return forOp->getResults();
77 for (
auto loopOp : loops)
90 if (!e.isFunctionOfDim(dim))
152 int64_t newDim = iteratorTypes.size();
153 iteratorTypes.push_back(iteratorTypes[dim]);
156 indexingMaps.size(), std::nullopt);
158 for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
160 AffineMap map = indexingMaps[operandIdx];
163 assert(map.
getNumDims() == newDim &&
"num dims invariant violation");
171 "num results invariant violation");
173 if (!maybeOperandDimensionToPack.has_value()) {
174 newMaps.push_back(map);
179 if (!isa<AffineDimExpr>(map.
getResult(maybeOperandDimensionToPack.value())))
185 newMaps.push_back(map);
188 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
190 indexingMaps = newMaps;
192 return packedDimPerIndexingMap;
198 struct PackedOperandsDim {
204 struct PackedOperandsDimList {
205 void pushBack(PackedOperandsDim &&packedOperandsDims) {
206 spec.emplace_back(packedOperandsDims);
220 tensor::PackOp packOp) {
222 auto packedTensorType =
223 cast<RankedTensorType>(packOp->getResultTypes().front());
224 if (llvm::any_of(packOp.getStaticInnerTiles(),
225 [](int64_t size) { return ShapedType::isDynamic(size); })) {
228 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
237 PackingMetadata packingMetadata = computePackingMetadata(
238 packedTensorType.getRank(), packOp.getInnerDimsPos());
252 for (
auto [pos, innerSize] :
253 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
255 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
265 rewriter, loc, map, {outerSize, origSize, innerSize});
267 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
269 packingMetadata.reassociations);
270 Value paddingValue = packOp.getPaddingValue();
272 paddingValue = rewriter.
create<arith::ConstantOp>(
276 rewriter.
create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
277 highs, paddingValue,
false);
280 DBGSNL();
DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
281 DBGS() <<
"insertPositions: ");
282 DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
283 DBGS() <<
"outerPositions: ");
284 DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
285 DBGS() <<
"packedShape: ");
287 llvm::interleaveComma(packedToStripMinedShapePerm,
288 DBGS() <<
"packedToStripMinedShapePerm: ");
289 DBGSNL(); llvm::interleaveComma(
290 packingMetadata.reassociations,
DBGS() <<
"reassociations: ",
292 llvm::interleaveComma(ri, llvm::dbgs() <<
"|");
295 llvm::interleaveComma(stripMinedShape,
DBGS() <<
"stripMinedShape: ");
298 if (packOp.isLikePad()) {
319 auto insertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
324 LLVM_DEBUG(
DBGS() <<
"insert_slice op: " << insertSliceOp;
DBGSNL(););
326 rewriter.
replaceOp(packOp, insertSliceOp->getResults());
333 auto reshapeOp = rewriter.
create<tensor::ExpandShapeOp>(
336 padOp.getResult(), packingMetadata.reassociations);
341 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
342 loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
345 DBGS() <<
"reshape op: " << reshapeOp;
DBGSNL();
346 llvm::interleaveComma(transpPerm,
DBGS() <<
"transpPerm: ");
350 rewriter.
replaceOp(packOp, transposeOp->getResults());
356 tensor::UnPackOp unPackOp) {
358 if (!unPackOp.getOuterDimsPerm().empty() &&
361 "non-identity outer dims perm NYI");
368 RankedTensorType packedTensorType = unPackOp.getSourceType();
369 int64_t packedRank = packedTensorType.getRank();
372 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
373 if (unPackOp.isLikeUnPad()) {
382 auto extractSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
383 loc, destTensorType, unPackOp.getSource(),
387 rewriter.
replaceOp(unPackOp, extractSliceOp->getResults());
390 nullptr, extractSliceOp};
394 int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
395 auto lastDims = llvm::to_vector(
396 llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
397 PackingMetadata packingMetadata =
398 computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
400 packedRank, lastDims, packingMetadata.insertPositions);
408 RankedTensorType stripMinedTensorType =
410 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
411 stripMinedTensorType, packingMetadata.reassociations);
418 auto emptyOp = rewriter.
create<tensor::EmptyOp>(
419 loc, dims, stripMinedTensorType.getElementType());
420 auto transposeOp = rewriter.
create<linalg::TransposeOp>(
421 loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
424 DBGSNL();
DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
425 DBGS() <<
"insertPositions: ");
426 DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
427 DBGS() <<
"packedShape: ");
429 llvm::interleaveComma(lastDimsToInsertPositionsPerm,
430 DBGS() <<
"lastDimsToInsertPositionsPerm: ");
431 DBGSNL(); llvm::interleaveComma(
432 packingMetadata.reassociations,
DBGS() <<
"reassociations: ",
434 llvm::interleaveComma(ri, llvm::dbgs() <<
"|");
437 llvm::interleaveComma(stripMinedShape,
DBGS() <<
"stripMinedShape: ");
441 auto reshapeOp = rewriter.
create<tensor::CollapseShapeOp>(
442 loc, collapsedType, transposeOp->getResult(0),
443 packingMetadata.reassociations);
446 int64_t destRank = destTensorType.getRank();
447 auto extractSliceOp = rewriter.
create<tensor::ExtractSliceOp>(
448 loc, destTensorType, reshapeOp->getResult(0),
454 auto copyOp = rewriter.
create<linalg::CopyOp>(
455 loc, extractSliceOp->getResult(0), unPackOp.getDest());
458 rewriter.
replaceOp(unPackOp, copyOp->getResults());
464 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
466 for (
auto &i : spec) {
467 if (!i.packedDimForEachOperand[operandPos].has_value())
469 res.push_back(i.packedDimForEachOperand[operandPos].value());
475 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
477 for (
auto &i : spec) {
478 if (!i.packedDimForEachOperand[operandPos].has_value())
480 res.push_back(i.packedSize);
489 linalg::LinalgOp linalgOp,
491 if (packedSizes.size() != linalgOp.getNumLoops()) {
493 "incorrect number of pack sizes");
499 linalgOp.getIteratorTypesArray();
500 LLVM_DEBUG(
DBGS() <<
"Start packing: " << linalgOp <<
"\n";
501 llvm::interleaveComma(indexingMaps,
DBGS() <<
"maps: ");
DBGSNL();
502 llvm::interleaveComma(iteratorTypes,
DBGS() <<
"iterators: ");
508 PackedOperandsDimList listOfPackedOperandsDim;
509 for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
512 if (maybeConstant.has_value() && maybeConstant.value() == 0)
515 PackedOperandsDim packedOperandsDims;
516 packedOperandsDims.packedSize = packedSizes[i];
518 maybePackedDimForEachOperand =
520 if (
failed(maybePackedDimForEachOperand))
522 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
523 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
526 DBGS() <<
"++++ After pack size #" << i <<
": " << packedSizes[i]
528 llvm::interleaveComma(indexingMaps,
DBGS() <<
"maps: ");
DBGSNL();
529 llvm::interleaveComma(iteratorTypes,
DBGS() <<
"iterators: ");
DBGSNL();
530 llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
531 DBGS() <<
"packedDimForEachOperand: ");
538 linalgOp.getDpsInitsMutable(), [](
OpOperand &o) { return &o; }));
540 for (
const auto &operandsList : {inputOperands, initOperands}) {
541 for (
OpOperand *opOperand : operandsList) {
542 int64_t pos = opOperand->getOperandNumber();
543 Value operand = opOperand->get();
545 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
547 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
549 DBGS() <<
"operand: " << operand <<
"\n";
550 llvm::interleaveComma(innerPos,
DBGS() <<
"innerPos: ");
DBGSNL();
551 llvm::interleaveComma(innerPackSizes,
DBGS() <<
"innerPackSizes: ");
553 if (innerPackSizes.empty()) {
554 inputsAndInits.push_back(operand);
557 Value dest = tensor::PackOp::createDestinationTensor(
558 rewriter, loc, operand, innerPackSizes, innerPos,
560 ShapedType operandType = cast<ShapedType>(operand.
getType());
561 bool areConstantTiles =
565 if (areConstantTiles && operandType.hasStaticShape() &&
566 !tensor::PackOp::requirePaddingValue(
567 operandType.getShape(), innerPos,
568 cast<ShapedType>(dest.
getType()).getShape(), {},
570 packOps.push_back(rewriter.
create<tensor::PackOp>(
571 loc, operand, dest, innerPos, innerPackSizes));
577 Value zero = rewriter.
create<arith::ConstantOp>(loc, zeroAttr);
578 packOps.push_back(rewriter.
create<tensor::PackOp>(
579 loc, operand, dest, innerPos, innerPackSizes, zero));
581 inputsAndInits.push_back(packOps.back());
587 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
589 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
590 auto packedLinalgOp = rewriter.
create<linalg::GenericOp>(
591 linalgOp.getLoc(), inits.
getTypes(), inputs, inits, indexingMaps,
596 for (
OpResult result : packedLinalgOp->getResults()) {
597 int64_t resultNum = result.getResultNumber();
598 tensor::PackOp maybePackedInit =
599 inits[resultNum].getDefiningOp<tensor::PackOp>();
600 if (!maybePackedInit) {
601 results.push_back(result);
605 unPackOps.push_back(rewriter.
create<tensor::UnPackOp>(
606 packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
607 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
608 results.push_back(unPackOps.back());
616 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
645 assert(linalgOp == opOperand.
getOwner() &&
"linalg op must own the operand");
649 cast<RankedTensorType>(opOperand.
get().
getType()), permutation);
651 assert(tensorType == transposedValue.
getType() &&
652 "expected tensor type mismatch");
657 llvm::map_range(permutation, [](int64_t i) ->
unsigned {
return i; }));
661 permutationMap.
compose(linalgOp.getMatchingIndexingMap(&opOperand));
665 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
671 auto transposedGenericOp = rewriter.
create<linalg::GenericOp>(
674 operandsRef.drop_front(linalgOp.getNumDpsInputs()).
getTypes(),
675 operandsRef.take_front(linalgOp.getNumDpsInputs()),
676 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
678 linalgOp.getIteratorTypesArray());
680 rewriter.
replaceOp(linalgOp, transposedGenericOp->getResults());
682 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
687 linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
694 tensor::PackOp transposedPackOp =
695 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
697 if (!packOp.getResult().hasOneUse())
700 OpOperand &packUse = *packOp->getUses().begin();
701 if (packUse.
getOwner() != linalgOp) {
703 linalgOp,
"not a single use by the LinalgOp target");
706 (!linalgOp.isDpsInit(&packUse) ||
707 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
709 "not produced by the LinalgOp target");
715 int64_t numLeadingDims = packOp.getSourceRank();
716 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
720 if (permutation.empty())
721 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
723 if (innerPerm.empty()) {
726 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
728 llvm::append_range(permutation,
729 llvm::map_range(innerPerm, [&](int64_t pos) {
730 return numLeadingDims + pos;
742 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
745 tensor::UnPackOp transposedUnPackOp;
748 transposedLinalgOp->getOpOperand(packUseOperandNumber);
749 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
751 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
752 rewriter, loc, transposedResult, innerPerm, outerPerm);
754 rewriter.
replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
758 rewriter.
replaceOp(packOp, transposedPackOp->getResults());
781 assert(mnkPackedSizes.size() == 3 &&
"unexpected num of packing sizes");
782 assert((mnkPaddedSizesNextMultipleOf.empty() ||
783 mnkPaddedSizesNextMultipleOf.size() == 3) &&
784 "num of packing sizes next multiple should be empty or of size 3");
785 assert(mnkOrder.size() == 3 &&
"unexpected mnkOrder size");
788 int64_t numLoops = linalgOp.getNumLoops();
790 LLVM_DEBUG(
DBGS() <<
"need 3+ loops to find a matmul to pack, got "
791 << numLoops <<
"\nin: " << linalgOp <<
"\n");
793 linalgOp,
"need 3+ loops to find a matmul to pack");
797 int64_t numPackedDims = mnkPackedSizes.size();
799 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
800 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
802 for (int64_t i = 0, e = numPackedDims; i < e; ++i)
803 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
805 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
806 paddedSizesNextMultipleOf[mnkOrder[i]] =
807 mnkPaddedSizesNextMultipleOf.empty() ? 0
808 : mnkPaddedSizesNextMultipleOf[i];
814 if (
failed(maybeDimensions)) {
815 LLVM_DEBUG(
DBGS() <<
"couldn't infer matmul iterators in: " << linalgOp
818 "couldn't infer matmul iterators");
826 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
827 kPos = maybeDimensions->k.back();
829 DBGS() <<
"Start packing generic op greedily with (m@" << mPos
830 <<
", n@" << nPos <<
", k@" << kPos <<
"): " << linalgOp
834 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
838 assert(
succeeded(generalizeResult) &&
"unexpected failure generalizing op");
839 genericOp = *generalizeResult;
847 LLVM_DEBUG(llvm::interleaveComma(permutation,
DBGS() <<
"perm: ");
DBGSNL(););
852 assert(
succeeded(interchangeResult) &&
"unexpected failure interchanging op");
853 genericOp = *interchangeResult;
854 LLVM_DEBUG(
DBGS() <<
"Generalized Op to pack: " << genericOp <<
"\n";);
871 cast<LinalgOp>(genericOp.getOperation())
872 .createLoopRanges(rewriter, genericOp.getLoc());
876 LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
877 DBGS() <<
"paddedSizesNextMultipleOf: ");
879 LLVM_DEBUG(llvm::interleaveComma(loopRanges,
DBGS() <<
"loopRanges: ",
880 [](
Range r) { llvm::dbgs() << r.
size; });
884 for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
885 if (paddedSizesNextMultipleOf[i] == 0) {
886 adjustedPackedSizes.push_back(packedSizes[i]);
893 rewriter, genericOp->getLoc(), d0.
ceilDiv(s0) * s0,
894 {loopRanges[adjustedPackedSizes.size()].size,
895 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
897 LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
898 DBGS() <<
"adjustedPackedSizes: ");
905 return pack(rewriter, genericOp, adjustedPackedSizes);
914 assert(!tileSizeComputationFunction &&
"tile sizes already set");
920 return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
938 auto padValue = padOp.getConstantPaddingValue();
940 return rewriter.
create<FillOp>(padOp.getLoc(), padValue, dest).result();
943 auto generateOp = rewriter.
create<tensor::GenerateOp>(
944 padOp.getLoc(), padOp.getResultType(), dynSizes);
947 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
956 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr))
960 padOp.getLoc(), cast<IntegerAttr>(ofr.get<
Attribute>()).getInt())
964 auto resultType = padOp.getResultType();
968 for (
unsigned dim = 0; dim < resultType.getRank(); ++dim) {
969 if (resultType.isDynamicDim(dim)) {
971 padOp.getSource(), dim));
974 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
976 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
977 dynSizes.push_back(plusHigh);
979 staticSizes.push_back(resultType.getDimSize(dim));
983 Value emptyTensor = rewriter.
create<tensor::EmptyOp>(
984 padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
993 auto sourceType = padOp.getSourceType();
1001 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
1009 if (!sliceOp.hasUnitStride())
1012 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
1016 bool zeroSliceGuard =
true;
1018 if (std::optional<bool> control = controlFn(sliceOp))
1019 zeroSliceGuard = *control;
1026 sliceOp.getMixedSizes(), zeroSliceGuard);
1027 if (
failed(tilingResult))
1031 rewriter.
replaceOp(sliceOp, tilingResult->tiledValues);
1038 tensor::PackOp packOp) {
1039 Value input = packOp.getSource();
1040 if (!packOp.getPaddingValue()) {
1045 ShapedType inputType = packOp.getSourceType();
1046 int64_t inputRank = inputType.getRank();
1047 assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
1048 [](int64_t val) { return val == 1; }));
1052 packOp.getDimAndTileMapping();
1053 for (int64_t dim = 0; dim < inputRank; ++dim) {
1054 int64_t size = inputType.getDimSize(dim);
1055 if (!tileAndPosMapping.count(dim)) {
1056 paddedShape.push_back(size);
1061 std::optional<int64_t> tileSize =
1063 assert(tileSize.has_value() &&
"dynamic inner tile size is not supported");
1064 paddedShape.push_back(tileSize.value());
1069 false, loc, builder);
1078 constexpr int64_t kNonTiledMarker = -1;
1083 vec, [&](int64_t v) {
return v != kNonTiledMarker; }));
1098 int64_t unpackedRank = shape.size();
1099 for (
auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1100 if (llvm::is_contained(innerDimsPos, i)) {
1101 innerDims.push_back(dim++);
1106 outerDims.push_back(dim++);
1107 if (!outerDimsPerm.empty())
1108 rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1114 applyPermutationToVector<int64_t>(innerDims, innerPerm);
1119 rankReducedOuterDimsPerm =
1121 if (!rankReducedOuterDimsPerm.empty())
1122 applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1125 perm.append(innerDims);
1132 if (llvm::any_of(packOp.getMixedTiles(),
1135 "require inner tile sizes being static");
1140 auto innerDimsPos = packOp.getInnerDimsPos();
1141 int64_t srcRank = packOp.getSourceRank();
1142 auto destShape = packOp.getDestType().getShape();
1143 if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
1144 return destShape[index] != 1;
1147 packOp,
"require the tiled outer dimensions of the result are all 1s");
1154 auto inputShape = packOp.getSourceType().getShape();
1156 packOp.getDimAndTileMapping();
1163 for (
auto i : llvm::seq<unsigned>(0, srcRank)) {
1164 if (dimAndTileMapping.count(i)) {
1166 .value_or(ShapedType::kDynamic));
1167 readSizes.push_back(dimAndTileMapping[i]);
1170 if (ShapedType::isDynamic(inputShape[i])) {
1171 readSizes.push_back(
1174 readSizes.push_back(rewriter.
getIndexAttr(inputShape[i]));
1176 if (inputShape[i] != 1)
1177 readShape.push_back(inputShape[i]);
1180 Type elemType = packOp.getSourceType().getElementType();
1184 loc, readType, input, readOffsets, readSizes, readStrides);
1189 inputShape, innerDimsPos, packOp.getOuterDimsPerm());
1191 LLVM_DEBUG(
DBGS() <<
"Pack permutation: " << packOp <<
"\n";
1192 llvm::interleaveComma(perm,
DBGS() <<
"perm: ");
DBGSNL(););
1195 applyPermutationToVector<int64_t>(transpShape, perm);
1197 Value empty = rewriter.
create<tensor::EmptyOp>(loc, transpShape, elemType);
1199 rewriter.
create<linalg::TransposeOp>(loc,
tile, empty, perm);
1202 int64_t destRank = packOp.getDestRank();
1208 auto insert = rewriter.
create<tensor::InsertSliceOp>(
1209 loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1210 writeSizes, writeStrides);
1211 rewriter.
replaceOp(packOp, insert.getResult());
1218 int64_t srcRank = unpackOp.getSourceRank();
1219 int64_t destRank = unpackOp.getDestRank();
1222 if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
1223 return srcShape[index] != 1;
1227 "require the tiled outer dimensions of the result are all 1s");
1232 Value source = unpackOp.getSource();
1234 unpackOp.getDimAndTileMapping();
1242 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1243 if (dimAndTileMapping.count(i)) {
1244 readSizes.push_back(oneIdxAttr);
1248 if (ShapedType::isDynamic(srcShape[i])) {
1250 rewriter.
create<tensor::DimOp>(loc, source, i).getResult();
1251 readSizes.push_back(dynamicDim);
1252 dynamicDims.push_back(dynamicDim);
1254 readSizes.push_back(rewriter.
getIndexAttr(srcShape[i]));
1256 if (srcShape[i] != 1)
1257 readShape.push_back(srcShape[i]);
1259 auto mixedTiles = unpackOp.getMixedTiles();
1260 readSizes.append(mixedTiles.begin(), mixedTiles.end());
1264 auto tileShape = srcShape.drop_front(destRank);
1266 readShape.append(tileShape.begin(), tileShape.end());
1267 Type elemType = unpackOp.getSourceType().getElementType();
1269 Value innerTile = rewriter.
create<tensor::ExtractSliceOp>(
1270 loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
1274 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1278 applyPermutationToVector<int64_t>(transpShape, perm);
1281 rewriter.
create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
1283 rewriter.
create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1287 int numLoops = transpShape.size();
1292 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1293 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1294 tileSizes.push_back(
1298 auto partialTile = rewriter.
create<tensor::ExtractSliceOp>(
1299 loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1305 for (
int i = 0, idx = 0; i < destRank; ++i) {
1306 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1307 writeSizes.push_back(tileSizes[idx++]);
1309 writeSizes.push_back(oneIdxAttr);
1311 auto insert = rewriter.
create<tensor::InsertSliceOp>(
1312 loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1314 rewriter.
replaceOp(unpackOp, insert.getResult());
1327 template <
typename Conv2DOp,
typename Conv1DOp>
1330 if (convOp.hasPureBufferSemantics())
1333 Value input = convOp.getInputs().front();
1334 Value kernel = convOp.getInputs().back();
1335 Value output = convOp.getOutputs().front();
1337 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1338 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1339 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1341 auto kernelShape = kernelType.getShape();
1342 auto outputShape = outputType.getShape();
1345 auto [khIndex, kwIndex, ohIndex, owIndex] =
1348 .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1349 return std::make_tuple(0, 1, 1, 2);
1351 .Case([&](linalg::Conv2DNchwFchwOp op) {
1352 return std::make_tuple(2, 3, 2, 3);
1354 .Case([&](linalg::PoolingNhwcSumOp op) {
1355 return std::make_tuple(0, 1, 1, 2);
1357 .Case([&](linalg::PoolingNchwSumOp op) {
1358 return std::make_tuple(0, 1, 2, 3);
1360 .Case([&](linalg::PoolingNhwcMaxOp op) {
1361 return std::make_tuple(0, 1, 1, 2);
1363 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1364 return std::make_tuple(0, 1, 1, 2);
1366 .Case([&](linalg::PoolingNhwcMinOp op) {
1367 return std::make_tuple(0, 1, 1, 2);
1369 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1370 return std::make_tuple(0, 1, 1, 2);
1372 .Case([&](linalg::PoolingNchwMaxOp op) {
1373 return std::make_tuple(0, 1, 2, 3);
1376 llvm_unreachable(
"unexpected conv2d/pool2d operation.");
1377 return std::make_tuple(0, 0, 0, 0);
1382 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1383 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1384 bool removeH = (khSize == 1 && ohSize == 1);
1385 bool removeW = (kwSize == 1 && owSize == 1);
1386 if (!removeH && !removeW)
1392 RankedTensorType newInputType =
1393 RTTBuilder(inputType).
dropDim((removeH ? ohIndex : owIndex));
1394 RankedTensorType newKernelType =
1395 RTTBuilder(kernelType).
dropDim((removeH ? khIndex : kwIndex));
1396 RankedTensorType newOutputType =
1397 RTTBuilder(outputType).
dropDim((removeH ? ohIndex : owIndex));
1402 rewriter, loc, input, newInputType);
1404 rewriter, loc, kernel, newKernelType);
1406 rewriter, loc, output, newOutputType);
1411 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1412 strides.erase(strides.begin() + (removeH ? 0 : 1));
1416 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1417 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1420 auto conv1DOp = rewriter.
create<Conv1DOp>(
1421 loc, newOutputType,
ValueRange{newInput, newKernel},
1422 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1426 rewriter, loc, conv1DOp.getResult(0), output);
1443 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1447 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1454 if (convOp.hasPureBufferSemantics())
1457 Value input = convOp.getInputs().front();
1458 Value kernel = convOp.getInputs().back();
1459 Value output = convOp.getOutputs().front();
1461 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1462 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1463 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1465 auto kernelShape = kernelType.getShape();
1466 auto outputShape = outputType.getShape();
1470 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1471 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1472 bool removeH = (khSize == 1 && ohSize == 1);
1473 bool removeW = (kwSize == 1 && owSize == 1);
1474 if (!removeH && !removeW)
1480 RankedTensorType newInputType =
1481 RTTBuilder(inputType).
dropDim((removeH ? 1 : 2));
1482 RankedTensorType newKernelType =
1483 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1484 RankedTensorType newOutputType =
1485 RTTBuilder(outputType).
dropDim(removeH ? 1 : 2);
1490 rewriter, loc, input, newInputType);
1492 rewriter, loc, kernel, newKernelType);
1494 rewriter, loc, output, newOutputType);
1498 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1499 strides.erase(strides.begin() + (removeH ? 0 : 1));
1503 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1504 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1507 auto conv1DOp = rewriter.
create<DepthwiseConv1DNwcWcOp>(
1508 loc, newOutputType,
ValueRange{newInput, newKernel},
1509 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1513 rewriter, loc, conv1DOp.getResult(0), output);
1522 if (convOp.hasPureBufferSemantics())
1525 Value input = convOp.getInputs().front();
1526 Value kernel = convOp.getInputs().back();
1527 Value output = convOp.getOutputs().front();
1529 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1530 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1531 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1533 auto kernelShape = kernelType.getShape();
1534 auto outputShape = outputType.getShape();
1538 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1539 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1540 bool removeH = (khSize == 1 && ohSize == 1);
1541 bool removeW = (kwSize == 1 && owSize == 1);
1542 if (!removeH && !removeW)
1548 RankedTensorType newInputType =
1549 RTTBuilder(inputType).
dropDim((removeH ? 0 : 1));
1550 RankedTensorType newKernelType =
1551 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1552 RankedTensorType newOutputType =
1553 RTTBuilder(outputType).
dropDim(removeH ? 0 : 1);
1558 rewriter, loc, input, newInputType);
1560 rewriter, loc, kernel, newKernelType);
1562 rewriter, loc, output, newOutputType);
1564 auto conv1DOp = rewriter.
create<Conv1DOp>(loc, newOutputType,
1570 rewriter, loc, conv1DOp.getResult(0), output);
1589 PoolingNwcMaxUnsignedOp>,
1592 PoolingNwcMinUnsignedOp>,
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Builder & setShape(ArrayRef< int64_t > newShape)
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite unpack as empty + transpose + reshape + extract_slice.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and applies affine_min/max bounds simplification on the fly where relevant.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapt the index accesses of op.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< int64_t > getPackInverseDestPerm(tensor::PackOp packOp)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
PadOp createPadHighOp(RankedTensorType type, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
This class represents an efficient way to signal success or failure.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
LogicalResult matchAndRewrite(tensor::PackOp packOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const override
OptimizeCopyFn optimizeCopyFn
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.