#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}
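
/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {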
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}
    if (!e.isFunctionOfDim(dim))
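
/// Pack the iterator `dim` into the LinalgOp metadata: append a new dim to
/// `iteratorTypes` and update each map in `indexingMaps`, returning, for each
/// operand, the operand dimension that was packed (if any). (Static helper;
/// name and signature are not visible in this listing and are assumed here.)
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {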
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to the map.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-1 index of the result that is a function of `dim`.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr results.
    if (!map.getResult(maybeOperandDimensionToPack.value())
             .isa<AffineDimExpr>())
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void push_back(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};
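
/// Rewrite pack as pad + reshape + transpose.
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,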
                                             tensor::PackOp packOp) {
  // Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);
  // Compute the permutation vector to shuffle the packed shape into the shape
  // before any outer or inner permutations have been applied.
  int64_t numPackedDims = packOp.getInnerDimsPos().size();
  int64_t packedRank = packedTensorType.getRank();
  auto lastDims = llvm::to_vector(
      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
      packedRank, lastDims, packingMetadata.insertPositions);

  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
  if (!outerPerm.empty())
    applyPermutationToVector(outerPos, outerPerm);
  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
      packedRank, packingMetadata.outerPositions, outerPos);

  SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
  applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);

  // Compute the stripMinedShape: the packed shape without outer and inner
  // permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int64_t outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);
  LLVM_DEBUG(
      DBGSNL(); DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
                                      DBGS() << "outerPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
      DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
                                      DBGS() << "innerPositionsPerm: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
  if (packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
    SliceVerificationResult rankReduces =
        isRankReducedType(packedTensorType, padOp.getResultType());
    if (rankReduces == SliceVerificationResult::Success) {
      // This pack is just a plain pad: insert the pad in the higher-ranked
      // tensor.
      auto emptyOp =
          rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
      SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
      SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());

      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
          loc, /*source=*/padOp, /*dest=*/emptyOp, /*offsets=*/zeros, sizes,
          /*strides=*/ones);

      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

      rewriter.replaceOp(packOp, insertSliceOp->getResults());
      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                             /*transposeOp=*/nullptr};
    }
  }

  // Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGSNL(); DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););

  rewriter.replaceOp(packOp, transposeOp->getResults());
  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
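
/// Rewrite unpack as empty + transpose + reshape + extract_slice.
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,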
                                                   tensor::UnPackOp unPackOp) {
  // Filter out NYI cases.
  if (!unPackOp.getOuterDimsPerm().empty())
    return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  if (!packedTensorType.hasStaticShape()) {
    return rewriter.notifyMatchFailure(
        unPackOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad: extract the slice from the
    // higher-ranked tensor. The inner dimensions stay the same as the
    // destination tensor; the outer ones are additional 1s.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }
  // Compute the permutation vector to shuffle the packed shape into the shape
  // before any outer or inner permutations have been applied.
  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
  auto lastDims = llvm::to_vector(
      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
  PackingMetadata packingMetadata =
      computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
      packedRank, lastDims, packingMetadata.insertPositions);

  // Compute the stripMinedShape: the packed shape without outer and inner
  // permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);

  // Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, stripMinedTensorType.getShape(),
      stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
  LLVM_DEBUG(
      DBGSNL(); DBGSNL(); DBGSNL();
      llvm::interleaveComma(packingMetadata.insertPositions,
                            DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(lastDimsToInsertPositionsPerm,
                            DBGS() << "lastDimsToInsertPositionsPerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
  // Collapse the stripMinedShape to the post-transposed shape.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // Extract the destination-sized slice and copy it into the destination.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  rewriter.replaceOp(unPackOp, copyOp->getResults());
  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
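
/// Return all the dims that have been packed for operand @ `operandPos` by
/// the entries of `spec`, skipping entries that did not pack this operand.
SmallVector<int64_t>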
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
    if (!spec[i].packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(spec[i].packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
    if (!spec[i].packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(spec[i].packedSize);
  }
  return res;
}
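
/// Implement packing of a single LinalgOp by `packedSizes`; there must be one
/// packed size per loop of the op.
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,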
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););
  SmallVector<tensor::PackOp> packOps;
  SmallVector<tensor::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.push_back(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););
  }
  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(DBGS() << "operand: " << operand << "\n";
                 llvm::interleaveComma(innerPos, DBGS() << "innerPos: ");
                 DBGSNL();
                 llvm::interleaveComma(innerPackSizes,
                                       DBGS() << "innerPackSizes: ");
                 DBGSNL(););
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
                                               innerPackSizes)) {
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }
  // Step 3. Build the packed op, using the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
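
/// Transpose the operand `opOperand` of `linalgOp` by `permutation` and
/// replace `linalgOp` by an equivalent transposed GenericOp whose matching
/// indexing map has been composed with the permutation. (Static helper; name
/// and signature are not visible in this listing and are assumed here.)
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {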
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape( // (helper name assumed)
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map. Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}
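
/// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the
/// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements.
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,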
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp. Note transposedPackOp.getOuterDimsPerm() may
  // be empty, in which case it is the identity; don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Compute the permutation on the whole operand.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // Complete the permutation with the consumed dims.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  assert(isPermutationVector(permutation) && "invalid permutation");

  // Save the operand number so we can get the tied OpResult after `linalgOp`
  // has been replaced, then actually perform the transposition.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);
    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
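
/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction dimension.
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {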
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }
  // Infer the matmul contraction dimensions.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // Normalize to a matmul-like form where (m, n, k) are the most minor
  // dimensions of the contraction.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);
  // Rewrite as a generic if needed.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // Interchange to move the (m, n, k) dimensions as most-minor iterators. This
  // only normalizes the iteration order; it does not change operand indexings.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
  SmallVector<Range> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops; only the last 3 dimensions are packed
  // post-interchange.
  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
                                   DBGS() << "paddedSizesNextMultipleOf: ");
             DBGSNL(););
  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
                                   [](Range r) { llvm::dbgs() << r.size; });
             DBGSNL(););
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
                                   DBGS() << "adjustedPackedSizes: ");
             DBGSNL(););
  return pack(rewriter, genericOp, adjustedPackedSizes);
}
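
/// Set the tileSizeComputationFunction to return the values `ts`.
LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::setTileSizes(
    ArrayRef<int64_t> ts) {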
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}
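
/// Filling `dest` using FillOp constant padding value if possible. Otherwise,
/// generate a tensor::GenerateOp. (Enclosing pattern class name is not visible
/// in this listing and is assumed here.)
Value GeneralizePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {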
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: lower to a tensor::GenerateOp with a region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy the region to the new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}
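
/// (Pattern class name assumed; signature grounded in the pattern's declared
/// matchAndRewrite(tensor::PadOp, PatternRewriter &) interface.)
LogicalResult
GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                        PatternRewriter &rewriter) const {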
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Try to optimize the copy of the source.
  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
    return success();

  // The copy cannot be optimized: generate an InsertSliceOp instead to copy
  // the PadOp source.
  auto sourceType = padOp.getSourceType();
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}
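
/// Swap a tensor.extract_slice with its producing tensor.pad by bubbling the
/// slice above the pad (delegates to tensor::bubbleUpPadSlice). (Pattern class
/// name assumed; not visible in this listing.)
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {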
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();
  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
  return success();
}
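
/// Return the source of `packOp` after padding it up to the inner tile sizes
/// with the pack's padding value, when padding is required. (Static helper;
/// its name and the first line of its signature are not visible in this
/// listing and are assumed here.)
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,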
                                           tensor::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    return input;
  }

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();
  assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
                      [](int64_t val) { return val == 1; }));

  SmallVector<int64_t> paddedShape;
  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();
  for (int64_t dim = 0; dim < inputRank; ++dim) {
    int64_t size = inputType.getDimSize(dim);
    if (!tileAndPosMapping.count(dim)) {
      paddedShape.push_back(size);
      continue;
    }

    // The size is less than or equal to tileSize because outer dims are all
    // 1s.
    std::optional<int64_t> tileSize =
        getConstantIntValue(tileAndPosMapping.lookup(dim));
    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
    paddedShape.push_back(tileSize.value());
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder);
}
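
/// Normalize a permutation on a higher-rank space to its actual size, e.g.
///   perm = [1, 4, 2]
/// becomes
///   norm = [0, 2, 1]
/// (Static helper; name assumed, not visible in this listing.)
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int64_t rank, ArrayRef<int64_t> perm) {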
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
  // This inverts the permutation in addition to normalizing, so invert back.
  return invertPermutationVector(normalizedPerm);
}

/// Compute the permutation that packs the source tensor, ignoring both the
/// size-1 outer dimensions and any outer transposition. (Static helper; name
/// assumed, not visible in this listing.)
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;

  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the innermost dims after packing.
  perm.append(innerDims);

  return perm;
}
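
/// Generalize a tensor.pack whose tiled outer dims are all 1s into a
/// rank-reduced extract_slice + transpose + insert_slice sequence.
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
    tensor::PackOp packOp, PatternRewriter &rewriter) const {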
  if (llvm::any_of(packOp.getMixedTiles(),
                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
    return rewriter.notifyMatchFailure(packOp,
                                       "require inner tile sizes being static");
  }

  // TODO: support the case where the outer dimensions are not all 1s. A
  // tensor.expand_shape will be generated in this case.
  auto innerDimsPos = packOp.getInnerDimsPos();
  int64_t srcRank = packOp.getSourceRank();
  auto destShape = packOp.getDestType().getShape();
  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
        return destShape[index] != 1;
      })) {
    return rewriter.notifyMatchFailure(
        packOp, "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use a rank-reduced tensor.extract_slice op to extract the tile.
  Location loc = packOp.getLoc();
  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  auto inputShape = packOp.getSourceType().getShape();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
  SmallVector<OpFoldResult> readSizes;
  SmallVector<int64_t> readShape;
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
                              .value_or(ShapedType::kDynamic));
      readSizes.push_back(dimAndTileMapping[i]);
      continue;
    }
    if (ShapedType::isDynamic(inputShape[i])) {
      readSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
    } else {
      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
    }
    if (inputShape[i] != 1)
      readShape.push_back(inputShape[i]);
  }

  Type elemType = packOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShape, elemType);

  Value tile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, input, readOffsets, readSizes, readStrides);

  // 2. Transpose the tile to match the inner tile order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      inputShape, innerDimsPos, packOp.getOuterDimsPerm());

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););

  SmallVector<int64_t> transpShape = readShape;
  applyPermutationToVector<int64_t>(transpShape, perm);

  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);

  // 3. Insert the inner tile into the destination.
  int64_t destRank = packOp.getDestRank();
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  SmallVector<OpFoldResult> writeSizes =
      tensor::getMixedSizes(rewriter, loc, packOp.getDest());

  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}
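
/// Generalize a tensor.unpack whose tiled outer dims are all 1s into a
/// rank-reduced extract_slice + transpose + extract_slice + insert_slice
/// sequence.
LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {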
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
        return srcShape[index] != 1;
      })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use a rank-reduced tensor.extract_slice op to extract the tile.
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
  SmallVector<OpFoldResult> readSizes;
  SmallVector<int64_t> readShape;
  SmallVector<Value> dynamicDims;
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i)) {
      readSizes.push_back(oneIdxAttr);
      continue;
    }

    if (ShapedType::isDynamic(srcShape[i])) {
      Value dynamicDim =
          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
      readSizes.push_back(dynamicDim);
      dynamicDims.push_back(dynamicDim);
    } else {
      readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    if (srcShape[i] != 1)
      readShape.push_back(srcShape[i]);
  }
  auto mixedTiles = unpackOp.getMixedTiles();
  readSizes.append(mixedTiles.begin(), mixedTiles.end());

  // Explicitly create the type for extract_slice op because the inner tile
  // size could be 1: we want to represent the whole inner tile in this case.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShape.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShape, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);

  // 2. Transpose the tile to match the outer corresponding tile order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space, so invert the permutation.
  perm = invertPermutationVector(perm);
  SmallVector<int64_t> transpShape(readShape);
  applyPermutationToVector<int64_t>(transpShape, perm);

  Value empty =
      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
  auto transposedOp =
      rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);

  // 3. Handle incomplete tiles: truncate trailing data from the transposed
  // tile.
  int numLoops = transpShape.size();
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }

  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);

  // 4. Insert the result into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
      writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}
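
/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
/// convolution ops.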
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Get domain indices based on the conv2D layout.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Default([&](Operation *op) {
            llvm_unreachable("unexpected conv2d/pool2d operation.");
            return std::make_tuple(0, 0, 0, 0);
          });

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<Conv1DOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
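
/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
/// dimensions into 1-D depthwise convolution ops.
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {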
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
      loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
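
/// Rewrites a plain 2-D convolution (no batch or channel dimensions) with a
/// size-1 window dimension into a 1-D convolution.
FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {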
  if (convOp.hasBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
                                            ValueRange{newInput, newKernel},
                                            ValueRange{newOutput});

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
                   PoolingNwcMaxUnsignedOp>,
                   PoolingNwcMinUnsignedOp>,