#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-transforms"
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}
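// Minimal usage sketch (assumes a `rewriter` and an `op` that may be an
// scf.for loop are in scope); ops that cannot be peeled simply return their
// own results unchanged:
//   SmallVector<Value> newResults = linalg::peelLoop(rewriter, op);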
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}
    if (!e.isFunctionOfDim(dim))
      continue;
static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
  return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
}
/// Pack a single dimension `dim` of the Linalg metadata: append a new
/// iterator of the same kind and extend each indexing map that references
/// `dim` with the new packed dimension. Returns, for each indexing map, the
/// result position that was packed (std::nullopt if the map is unaffected).
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the new dimension to the map.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // At most one result of the map may be a function of `dim`.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // Only AffineDimExpr results can be packed at the moment.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` at the end of the map's results.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` was packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
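// Worked example (illustrative): packing the k dimension of a matmul whose
// maps are (m, n, k) -> (m, k), (k, n), (m, n) appends a new iterator kk and
// rewrites the maps to (m, n, k, kk) -> (m, k, kk), (k, n, kk), (m, n). The
// returned vector records, per operand, which result position was packed
// (here: 1, 0, std::nullopt); the packed size of kk is supplied separately by
// the caller.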
namespace {

/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes for operand @ `operandPos`.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};

} // namespace
FailureOr<LowerPackResult>
linalg::lowerPack(RewriterBase &rewriter, linalg::PackOp packOp,
                  bool lowerPadLikeWithInsertSlice) {
  if (!packOp.hasPureTensorSemantics())
    return failure();

  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  // 2. Compute the permutation vector to shuffle the packed shape into the
  // shape before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  // ... (loc, packedToStripMinedShapePerm, stripMinedShape and the lows/highs
  // pad bounds are computed in code elided from this listing)

  // 3. Pad the source of packOp to a shape we can expand into stripMinedShape.
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    // ...
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = arith::ConstantOp::create(
        rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
                            highs, paddingValue, /*nofold=*/false);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "outerPositions: "
         << llvm::interleaved(packingMetadata.outerPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsed;

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack is pad-like: insert the padded source directly into the dest.
    // (The zero offsets, sizes and unit strides are computed in code elided
    // from this listing.)
    auto insertSliceOp = tensor::InsertSliceOp::create(
        rewriter, loc, padOp, packOp.getDest(),
        /*offsets=*/zeros, /*sizes=*/sizes, /*strides=*/ones);
    LDBG() << "insert_slice op: " << insertSliceOp;
    rewriter.replaceOp(packOp, insertSliceOp->getResults());
    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                           /*transposeOp=*/nullptr};
  }

  // 4. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = tensor::ExpandShapeOp::create(
      rewriter, loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 5. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = linalg::TransposeOp::create(
      rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LDBG() << "reshape op: " << reshapeOp;
  LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
  LDBG() << "transpose op: " << transposeOp;

  // 6. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());
  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
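// Usage sketch (assuming a RewriterBase `rewriter` and a linalg.pack op
// `packOp` with static inner tiles are in scope):
//   FailureOr<LowerPackResult> res = linalg::lowerPack(rewriter, packOp);
//   if (failed(res))
//     return failure();
// The result exposes the created pad/reshape/transpose ops for further
// processing.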
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  if (!unPackOp.hasPureTensorSemantics())
    return failure();

  Location loc = unPackOp->getLoc();
  auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad: extract the slice directly from the
    // higher-ranked source. The inner dims keep the destination sizes; the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destTensorType.getRank(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = tensor::ExtractSliceOp::create(
        rewriter, loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle the packed shape into the
  // shape before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  // ... (packedToStripMinedShapePerm and stripMinedShape are computed in code
  // elided from this listing)

  // 2. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get the dynamic dims, permuted by packedToStripMinedShapePerm.
  SmallVector<OpFoldResult> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
                                         stripMinedTensorType.getElementType());
  auto transposeOp =
      linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
                                  packedToStripMinedShapePerm);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsedType;

  // 3. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = tensor::CollapseShapeOp::create(
      rewriter, loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 4. Extract the destination-sized slice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = tensor::ExtractSliceOp::create(
      rewriter, loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 5. Inject a copy to preserve DPS.
  auto copyOp = linalg::CopyOp::create(
      rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 6. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());
  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
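// Usage sketch, mirroring lowerPack (assuming `rewriter` and a linalg.unpack
// op `unPackOp` are in scope):
//   FailureOr<LowerUnPackOpResult> res = linalg::lowerUnPack(rewriter, unPackOp);
//   if (failed(res))
//     return failure();
// When the isLikeUnPad fast path fires, some of the returned ops (empty,
// transpose, reshape) are null, as in the early return above.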
SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LDBG() << "Start packing: " << linalgOp;
  LDBG() << "maps: " << llvm::interleaved(indexingMaps);
  LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);

  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  SmallVector<linalg::PackOp> packOps;
  SmallVector<linalg::UnPackOp> unPackOps;
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;

    LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
    LDBG() << "maps: " << llvm::interleaved(indexingMaps);
    LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
    LDBG() << "packedDimForEachOperand: "
           << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);

    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands =
      llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LDBG() << "operand: " << operand;
      LDBG() << "innerPos: " << llvm::interleaved(innerPos);
      LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = linalg::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !linalg::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
                                                 innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
        packOps.push_back(linalg::PackOp::create(
            rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back().getResult());
    }
  }

  // Step 3. Build the packed op, using the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp =
      linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
                                inputs, inits, indexingMaps, iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    linalg::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<linalg::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(linalg::UnPackOp::create(
        rewriter, packedLinalgOp->getLoc(), result,
        maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(),
        maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back().getResult());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
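// Usage sketch (assuming `rewriter` and a three-loop LinalgOp `linalgOp`;
// the sizes are hypothetical): pack the third loop dimension by 32 and leave
// the others unpacked (a constant size of 0 skips a dimension, see Step 1):
//   SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(0),
//                                            rewriter.getIndexAttr(0),
//                                            rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> res = linalg::pack(rewriter, linalgOp, packedSizes);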
/// Transpose the operand `opOperand` of `linalgOp` by `permutation`, replace
/// it with `transposedValue`, and rewrite `linalgOp` as a GenericOp with the
/// matching indexing map composed with the permutation.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map.
  // Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition = llvm::map_to_vector(
      permutation, [](int64_t i) -> unsigned { return i; });
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = linalg::GenericOp::create(
      rewriter, linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
                      linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  linalg::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
    return failure();

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity: don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // The leading part just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // The trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  // ...

  // Step 2.b. Save the operand number, then actually perform the
  // transposition.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  linalg::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  if (packOp.hasPureTensorSemantics())
    rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
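// Usage sketch (assuming the packOp -> linalgOp -> unPackOp chain produced by
// a previous pack(...) call, plus user-chosen permutations):
//   FailureOr<PackTransposeResult> res = linalg::packTranspose(
//       rewriter, packOp, linalgOp, maybeUnPackOp, outerPerm, innerPerm);
// Empty outerPerm/innerPerm are treated as identity, as in Step 2.a above.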
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
           << " in: " << linalgOp;
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to a matmul-like form with (m, n, k) as the
  // most-minor iterators.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
         << nPos << ", k@" << kPos << "): " << linalgOp;

  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (m, n, k) as most-minor iterators.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LDBG() << "perm: " << llvm::interleaved(permutation);
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LDBG() << "Generalized Op to pack: " << genericOp;

  // 3. Compute the packed sizes, rounding up to the requested next multiple
  // where applicable.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops.
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  LDBG() << "paddedSizesNextMultipleOf: "
         << llvm::interleaved(paddedSizesNextMultipleOf);
  LDBG() << "loopRanges: "
         << llvm::interleaved(
                llvm::map_range(loopRanges, [](Range r) { return r.size; }));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);

  // 4. Pack the genericOp with the adjusted sizes.
  return pack(rewriter, genericOp, adjustedPackedSizes);
}
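// Usage sketch (hypothetical sizes): pack m, n, k by 8, 16, 32 with no
// next-multiple-of padding and the default (m, n, k) order:
//   SmallVector<OpFoldResult> mnk = {rewriter.getIndexAttr(8),
//                                    rewriter.getIndexAttr(16),
//                                    rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> res = linalg::packMatmulGreedily(
//       rewriter, linalgOp, mnk, /*mnkPaddedSizesNextMultipleOf=*/{},
//       /*mnkOrder=*/{0, 1, 2});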
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::map_to_vector<4>(tileSizes, [&](int64_t s) {
      Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
      return v;
    });
  };
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue) {
    // Move a padding value defined inside the PadOp block outside of it.
    if (padValue.getParentBlock() == &padOp.getRegion().front())
      rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
    return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
  }

  // Fill could not be optimized: lower to tensor::GenerateOp with region.
  auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
                                               padOp.getResultType(), dynSizes);
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return arith::ConstantIndexOp::create(
               rewriter, padOp.getLoc(),
               cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the sizes of the EmptyOp; any static/dynamic mix is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      Value srcSize = getIdxValue(tensor::getMixedSize(
          rewriter, padOp.getLoc(), padOp.getSource(), dim));
      // Add the low and high padding to the source size.
      Value plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      Value plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor, fill it with padding, then insert the source into it.
  Value emptyTensor =
      tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
                              resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  auto sourceType = padOp.getSourceType();
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);
  return success();
}
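// At the IR level, this decomposition turns (illustrative):
//   %0 = tensor.pad %src low[...] high[...] { yield %cst }
// into:
//   %empty = tensor.empty(...)
//   %fill  = linalg.fill ins(%cst) outs(%empty)
//   %0     = tensor.insert_slice %src into %fill[...]
// with a tensor.generate fallback when the padding value is not a constant.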
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();

  // If the extract_slice is not rank-reducing, we can replace it directly.
  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
    return success();
  }

  // Handle rank-reducing slices with another, canonical extract_slice.
  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
  rewriter.replaceOp(sliceOp, rankReduced);
  return success();
}
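// The net effect is to swap extract_slice(pad(x)) into pad(extract_slice(x))
// via tensor::bubbleUpPadSlice; zeroSliceGuard conservatively guards against
// zero-sized slices unless the control function opts out.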
/// If a padding value is set, return a tensor.pad op for the source of
/// `packOp`; otherwise return the source directly. Assumes all outer dims of
/// the pack are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           linalg::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.hasPureTensorSemantics())
    return input;
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles.
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims: preserve them (they are all 1).
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims: as all outer dims are 1, the tile size gives the
    // padded size for this dim.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1. Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2. Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
      vec, [&](int64_t v) { return v != kNonTiledMarker; });
  // This inverts the permutation in addition to normalizing, so invert back.
  return invertPermutationVector(normalizedPerm);
}
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> innerDims, outerDims, rankReducedOuterDimsPerm;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;
  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the innermost dims after packing.
  perm.append(innerDims);
  return perm;
}
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    linalg::PackOp packOp, PatternRewriter &rewriter) const {
  if (!packOp.hasPureTensorSemantics())
    return failure();

  // TODO: support outer dimensions that are not all 1s (this would generate a
  // tensor.expand_shape).
  if (llvm::any_of(packOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  auto outerDimsPerm = packOp.getOuterDimsPerm();
  auto innerDimsPos = packOp.getInnerDimsPos();
  if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
        static int prev = 0;
        // Tiled dims are not considered here.
        if (llvm::is_contained(innerDimsPos, dim))
          return true;
        if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
                           packOp.getResult().getType().getShape()[dim] != 1))
          return false;
        prev = dim;
        return true;
      })) {
    return rewriter.notifyMatchFailure(
        packOp, "At least one non-unit and un-tiled outer dim is permuted, "
                "this is not supported ATM!");
  }

  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  Location loc = packOp.getLoc();
  int64_t srcRank = packOp.getSourceRank();

  // 1. Get the input that is going to be packed. If the input requires
  // padding, pad it first.
  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);

  // 2. Transpose the input so that the tiled dims (innerDimsPos) end up last.
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < srcRank; i++) {
    if (llvm::is_contained(innerDimsPos, i))
      continue;
    srcPermForTranspose.push_back(i);
  }
  srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

  // 2.1. Create tensor.empty (the init value for TransposeOp).
  ShapedType inputTy = cast<ShapedType>(input.getType());
  SmallVector<OpFoldResult> shapeForEmptyOp;
  for (int64_t i = 0; i < srcRank; i++) {
    if (llvm::is_contained(innerDimsPos, i)) {
      // The tiled dims are appended after this loop.
      continue;
    }
    if (inputTy.isStaticDim(i))
      shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
    else
      shapeForEmptyOp.emplace_back(
          tensor::DimOp::create(rewriter, loc, input, i).getResult());
  }
  shapeForEmptyOp.append(packOp.getMixedTiles());

  // Constant tile sizes may be wrapped as Values; normalize them back to
  // attributes where possible.
  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
                  [](OpFoldResult ofr) {
                    if (auto val = llvm::dyn_cast<Value>(ofr))
                      return getAsOpFoldResult(val);
                    return ofr;
                  });

  LDBG() << "Pack permutation: " << packOp;
  LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);

  Value empty = tensor::EmptyOp::create(
      rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.2. Create linalg.transpose.
  auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                  srcPermForTranspose);

  // 3. Insert the inner tile into the destination. All outer dims are 1, so
  // the corresponding write sizes are 1; the trailing sizes are the tiles.
  SmallVector<OpFoldResult> writeSizes;
  for (auto size : packOp.getAllOuterDims()) {
    (void)size;
    writeSizes.push_back(oneIdxAttr);
  }
  for (auto tileSize : packOp.getMixedTiles()) {
    auto [_, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
  }
  auto insert = tensor::InsertSliceOp::create(
      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);

  // 4. Replace the pack with the insert_slice.
  rewriter.replaceOp(packOp, insert.getResult());
  return success();
}
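// The net decomposition for an all-unit-outer-dims pack is: an optional
// tensor.pad of the source, a linalg.transpose that moves the tiled dims
// innermost, and a tensor.insert_slice into the destination.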
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  if (!unpackOp.hasPureTensorSemantics())
    return failure();

  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use a rank-reduced tensor.extract_slice to extract the tile.
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);

  // The sizes consist of three blocks of dims:
  //   [ outer-untiled-dims, outer-tiled-dims (all 1), tile-sizes ].
  SmallVector<int64_t> readShapeForExtractSlice;
  SmallVector<OpFoldResult> extractSliceSizes;
  SmallVector<OpFoldResult> shapeForEmptyOp;
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Given the assumption that all outer tiled dims are 1, the corresponding
    // slice size to read is also 1; being a rank-reducing read, the unit dims
    // are collapsed and need not be stored.
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
      continue;
    }

    // Sizes for the ExtractSliceOp and EmptyOp - outer-untiled-dims.
    if (ShapedType::isDynamic(srcShape[i])) {
      OpFoldResult dynamicDim =
          tensor::DimOp::create(rewriter, loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
    } else {
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    // Result shape for the ExtractSliceOp - only non-unit dims are read.
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);
    }
  }
  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  // The tile shape is the trailing part of the source shape.
  auto tileShape = srcShape.drop_front(destRank);
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = tensor::ExtractSliceOp::create(
      rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);

  // 2. Transpose the tile to match the outer dim order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space, so invert the permutation.
  perm = invertPermutationVector(perm);
  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);

  Value empty =
      tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);

  // 3. Handle incomplete tiles by extracting the destination-sized part.
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  SmallVector<OpFoldResult> tileSizes;
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }
  auto partialTile =
      tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
                                     transposedOp.getResult()[0], tileSizes);

  // 4. Insert the result into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
                                              unpackOp.getDest(), writeSizes);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}
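// Symmetrically to the pack case, the decomposition is: a rank-reduced
// tensor.extract_slice of the tile, a linalg.transpose back to the
// destination order, and extract_slice/insert_slice to handle incomplete
// tiles and to write the result into the destination.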
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp>
DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const {
  std::optional<DilationsAndStrides> convParams =
      matchConvolutionOpOfType<Conv2DOp>(convOp);
  if (!convParams.has_value())
    return failure();
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getDpsInputs().front();
  Value kernel = convOp.getDpsInputs().back();
  Value output = convOp.getDpsInits().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Get domain indices based on the conv2D layout.
  int64_t khIndex, kwIndex, ohIndex, owIndex;
  if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
    khIndex = 0;
    kwIndex = 1;
    ohIndex = 1;
    owIndex = 2;
  } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
    khIndex = 2;
    kwIndex = 3;
    ohIndex = 2;
    owIndex = 3;
  } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
                       std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
    khIndex = 0;
    kwIndex = 1;
    ohIndex = 2;
    owIndex = 3;
  } else {
    llvm_unreachable("unexpected conv2d/pool2d operation.");
  }

  // Only handle the case where at least one of the window dimensions is of
  // size 1; other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too (extraction from convParams is
  // elided from this listing).
  // ...
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);
  // ...
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = Conv1DOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert the result back into the rank-expanded output.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}

// ... (other explicit instantiations elided)
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
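// Example (illustrative): a linalg.conv_2d_nhwc_hwcf whose kernel height and
// output height are both 1 is rewritten into a 1-D convolution on
// rank-reduced operands, with the corresponding stride/dilation entry erased.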
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    LinalgOp convOp, PatternRewriter &rewriter) const {
  std::optional<DilationsAndStrides> convParams =
      matchConvolutionOpOfType<DepthwiseConv2DNhwcHwcOp>(convOp);
  if (!convParams.has_value())
    return failure();
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getDpsInputs().front();
  Value kernel = convOp.getDpsInputs().back();
  Value output = convOp.getDpsInits().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1; other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too (extraction from convParams is
  // elided from this listing).
  // ...
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);
  // ...
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert the result back into the rank-expanded output.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}
FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(LinalgOp convOp,
                                            PatternRewriter &rewriter) const {
  std::optional<DilationsAndStrides> convParams =
      matchConvolutionOpOfType<Conv2DOp>(convOp);
  if (!convParams.has_value())
    return failure();
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getDpsInputs().front();
  Value kernel = convOp.getDpsInputs().back();
  Value output = convOp.getDpsInits().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1; other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp =
      Conv1DOp::create(rewriter, loc, newOutputType,
                       ValueRange{newInput, newKernel}, ValueRange{newOutput});

  // Insert the result back into the rank-expanded output.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);
  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<
      // ... (other downscaling pattern instantiations elided)
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                            PoolingNwcMaxUnsignedOp>,
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                            PoolingNwcMinUnsignedOp>,
      DownscaleConv2DOp>(patterns.getContext(), benefit);
}
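// Typical driver-side usage (sketch, assuming an MLIRContext `ctx` and a
// root operation `op` are in scope):
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));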