32#include "llvm/ADT/SmallVectorExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/DebugLog.h"
36#include "llvm/Support/InterleavedRange.h"
37#include "llvm/Support/raw_ostream.h"
41#define DEBUG_TYPE "linalg-transforms"
60 .Case([&](scf::ForOp forOp) {
61 scf::ForOp partialIteration;
64 return partialIteration->getResults();
65 assert(!partialIteration &&
"expected that loop was not peeled");
66 return forOp->getResults();
75 for (
auto loopOp : loops)
88 if (!e.isFunctionOfDim(dim))
99 return llvm::interleaved(ri,
", ",
"|",
"");
150static FailureOr<SmallVector<std::optional<int64_t>>>
154 int64_t newDim = iteratorTypes.size();
155 iteratorTypes.push_back(iteratorTypes[dim]);
158 indexingMaps.size(), std::nullopt);
160 for (
int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
162 AffineMap map = indexingMaps[operandIdx];
165 assert(map.
getNumDims() == newDim &&
"num dims invariant violation");
173 "num results invariant violation");
175 if (!maybeOperandDimensionToPack.has_value()) {
176 newMaps.push_back(map);
181 if (!isa<AffineDimExpr>(map.
getResult(maybeOperandDimensionToPack.value())))
187 newMaps.push_back(map);
190 packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
192 indexingMaps = newMaps;
194 return packedDimPerIndexingMap;
200struct PackedOperandsDim {
201 OpFoldResult packedSize;
202 SmallVector<std::optional<int64_t>> packedDimForEachOperand;
206struct PackedOperandsDimList {
207 void pushBack(PackedOperandsDim &&packedOperandsDims) {
208 spec.emplace_back(packedOperandsDims);
211 SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
213 SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
216 SmallVector<PackedOperandsDim> spec;
222 linalg::PackOp packOp,
223 bool lowerPadLikeWithInsertSlice) {
225 if (!packOp.hasPureTensorSemantics())
229 auto packedTensorType =
230 cast<RankedTensorType>(packOp->getResultTypes().front());
231 if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
234 "non-static shape NYI, needs a more powerful tensor.expand_shape op");
243 PackingMetadata packingMetadata;
257 for (
auto [pos, innerSize] :
258 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
260 packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
270 rewriter, loc, map, {outerSize, origSize, innerSize});
272 RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
274 packingMetadata.reassociations);
275 Value paddingValue = packOp.getPaddingValue();
277 paddingValue = arith::ConstantOp::create(
281 tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
282 highs, paddingValue,
false);
284 LDBG() <<
"insertPositions: "
285 << llvm::interleaved(packingMetadata.insertPositions);
286 LDBG() <<
"outerPositions: "
287 << llvm::interleaved(packingMetadata.outerPositions);
288 LDBG() <<
"packedShape: " << llvm::interleaved(packedTensorType.getShape());
289 LDBG() <<
"packedToStripMinedShapePerm: "
290 << llvm::interleaved(packedToStripMinedShapePerm);
291 LDBG() <<
"reassociations: "
292 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
294 LDBG() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
295 LDBG() <<
"collapsed type: " << collapsed;
297 if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
316 auto insertSliceOp = tensor::InsertSliceOp::create(
317 rewriter, loc, padOp, packOp.getDest(),
320 LDBG() <<
"insert_slice op: " << insertSliceOp;
322 rewriter.
replaceOp(packOp, insertSliceOp->getResults());
330 auto expandShapeResultType =
332 auto reshapeOp = tensor::ExpandShapeOp::create(
333 rewriter, loc, expandShapeResultType, padOp.getResult(),
334 packingMetadata.reassociations);
339 auto transposeOp = linalg::TransposeOp::create(
340 rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
342 LDBG() <<
"reshape op: " << reshapeOp;
343 LDBG() <<
"transpPerm: " << llvm::interleaved(transpPerm);
344 LDBG() <<
"transpose op: " << transposeOp;
347 rewriter.
replaceOp(packOp, transposeOp->getResults());
352FailureOr<LowerUnPackOpResult>
354 bool lowerUnpadLikeWithExtractSlice) {
356 if (!unPackOp.hasPureTensorSemantics())
363 auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
364 int64_t packedRank = packedTensorType.getRank();
367 auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
368 if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
377 auto extractSliceOp = tensor::ExtractSliceOp::create(
378 rewriter, loc, destTensorType, unPackOp.getSource(),
382 rewriter.
replaceOp(unPackOp, extractSliceOp->getResults());
385 nullptr, extractSliceOp,
391 PackingMetadata packingMetadata;
401 RankedTensorType stripMinedTensorType =
403 RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
404 stripMinedTensorType, packingMetadata.reassociations);
411 auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
412 stripMinedTensorType.getElementType());
414 linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
415 packedToStripMinedShapePerm);
417 LDBG() <<
"insertPositions: "
418 << llvm::interleaved(packingMetadata.insertPositions);
419 LDBG() <<
"packedShape: " << llvm::interleaved(packedTensorType.getShape());
420 LDBG() <<
"packedToStripMinedShapePerm: "
421 << llvm::interleaved(packedToStripMinedShapePerm);
422 LDBG() <<
"reassociations: "
423 << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
425 LDBG() <<
"stripMinedShape: " << llvm::interleaved(stripMinedShape);
426 LDBG() <<
"collapsed type: " << collapsedType;
429 auto reshapeOp = tensor::CollapseShapeOp::create(
430 rewriter, loc, collapsedType, transposeOp->getResult(0),
431 packingMetadata.reassociations);
434 int64_t destRank = destTensorType.getRank();
435 auto extractSliceOp = tensor::ExtractSliceOp::create(
436 rewriter, loc, destTensorType, reshapeOp->getResult(0),
442 auto copyOp = linalg::CopyOp::create(
443 rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
446 rewriter.
replaceOp(unPackOp, copyOp->getResults());
453PackedOperandsDimList::extractPackedDimsForOperand(
int64_t operandPos) {
455 for (
auto &i : spec) {
456 if (!i.packedDimForEachOperand[operandPos].has_value())
458 res.push_back(i.packedDimForEachOperand[operandPos].value());
463SmallVector<OpFoldResult>
464PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
465 SmallVector<OpFoldResult> res;
466 for (
auto &i : spec) {
467 if (!i.packedDimForEachOperand[operandPos].has_value())
469 res.push_back(i.packedSize);
478 linalg::LinalgOp linalgOp,
480 if (packedSizes.size() != linalgOp.getNumLoops()) {
482 "incorrect number of pack sizes");
488 linalgOp.getIteratorTypesArray();
489 LDBG() <<
"Start packing: " << linalgOp;
490 LDBG() <<
"maps: " << llvm::interleaved(indexingMaps);
491 LDBG() <<
"iterators: " << llvm::interleaved(iteratorTypes);
496 PackedOperandsDimList listOfPackedOperandsDim;
497 for (
int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
500 if (maybeConstant.has_value() && maybeConstant.value() == 0)
503 PackedOperandsDim packedOperandsDims;
504 packedOperandsDims.packedSize = packedSizes[i];
505 FailureOr<SmallVector<std::optional<int64_t>>>
506 maybePackedDimForEachOperand =
508 if (failed(maybePackedDimForEachOperand))
510 packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
512 LDBG() <<
"++++ After pack size #" << i <<
": " << packedSizes[i];
513 LDBG() <<
"maps: " << llvm::interleaved(indexingMaps);
514 LDBG() <<
"iterators: " << llvm::interleaved(iteratorTypes);
515 LDBG() <<
"packedDimForEachOperand: "
516 << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
518 listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
524 llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
526 for (
const auto &operandsList : {inputOperands, initOperands}) {
527 for (
OpOperand *opOperand : operandsList) {
528 int64_t pos = opOperand->getOperandNumber();
529 Value operand = opOperand->get();
531 listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
533 listOfPackedOperandsDim.extractPackSizesForOperand(pos);
534 LDBG() <<
"operand: " << operand;
535 LDBG() <<
"innerPos: " << llvm::interleaved(innerPos);
536 LDBG() <<
"innerPackSizes: " << llvm::interleaved(innerPackSizes);
537 if (innerPackSizes.empty()) {
538 inputsAndInits.push_back(operand);
541 Value dest = linalg::PackOp::createDestinationTensor(
542 rewriter, loc, operand, innerPackSizes, innerPos,
544 ShapedType operandType = cast<ShapedType>(operand.
getType());
545 bool areConstantTiles =
549 if (areConstantTiles && operandType.hasStaticShape() &&
550 !linalg::PackOp::requirePaddingValue(
551 operandType.getShape(), innerPos,
552 cast<ShapedType>(dest.
getType()).getShape(), {},
554 packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
555 innerPos, innerPackSizes));
561 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
562 packOps.push_back(linalg::PackOp::create(
563 rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
565 inputsAndInits.push_back(packOps.back().getResult());
571 ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
573 ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
574 auto packedLinalgOp =
575 linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.
getTypes(),
576 inputs, inits, indexingMaps, iteratorTypes);
577 packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
582 linalg::PackOp maybePackedInit =
583 inits[resultNum].getDefiningOp<linalg::PackOp>();
584 if (!maybePackedInit) {
585 results.push_back(
result);
589 unPackOps.push_back(linalg::UnPackOp::create(
590 rewriter, packedLinalgOp->getLoc(),
result, maybePackedInit.getSource(),
591 maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
592 results.push_back(unPackOps.back().getResult());
600 cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
629 assert(linalgOp == opOperand.
getOwner() &&
"linalg op must own the operand");
633 cast<RankedTensorType>(opOperand.
get().
getType()), permutation);
635 assert(tensorType == transposedValue.
getType() &&
636 "expected tensor type mismatch");
641 llvm::map_to_vector(permutation, [](
int64_t i) ->
unsigned {
return i; });
645 permutationMap.
compose(linalgOp.getMatchingIndexingMap(&opOperand));
649 indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
655 auto transposedGenericOp = linalg::GenericOp::create(
659 operandsRef.drop_front(linalgOp.getNumDpsInputs()).
getTypes(),
660 operandsRef.take_front(linalgOp.getNumDpsInputs()),
661 operandsRef.drop_front(linalgOp.getNumDpsInputs()),
663 linalgOp.getIteratorTypesArray());
664 transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
665 rewriter.
replaceOp(linalgOp, transposedGenericOp->getResults());
667 return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
670FailureOr<PackTransposeResult>
672 linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
679 linalg::PackOp transposedPackOp =
680 packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
682 if (packOp.hasPureBufferSemantics() || !packOp.getResult().hasOneUse())
685 OpOperand &packUse = *packOp->getUses().begin();
686 if (packUse.
getOwner() != linalgOp) {
688 linalgOp,
"not a single use by the LinalgOp target");
691 (!linalgOp.isDpsInit(&packUse) ||
692 maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
694 "not produced by the LinalgOp target");
700 int64_t numLeadingDims = packOp.getSourceRank();
701 int64_t numTrailingDims = packOp.getInnerDimsPos().size();
705 if (permutation.empty())
706 llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
708 if (innerPerm.empty()) {
711 llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
713 llvm::append_range(permutation,
714 llvm::map_range(innerPerm, [&](
int64_t pos) {
715 return numLeadingDims + pos;
727 rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
730 linalg::UnPackOp transposedUnPackOp;
733 transposedLinalgOp->getOpOperand(packUseOperandNumber);
734 OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
736 transposedUnPackOp = maybeUnPackOp.createTransposedClone(
737 rewriter, loc, transposedResult, innerPerm, outerPerm);
739 rewriter.
replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
743 if (packOp.hasPureTensorSemantics())
744 rewriter.
replaceOp(packOp, transposedPackOp->getResults());
769 assert(mnkPackedSizes.size() == 3 &&
"unexpected num of packing sizes");
770 assert((mnkPaddedSizesNextMultipleOf.empty() ||
771 mnkPaddedSizesNextMultipleOf.size() == 3) &&
772 "num of packing sizes next multiple should be empty or of size 3");
773 assert(mnkOrder.size() == 3 &&
"unexpected mnkOrder size");
776 int64_t numLoops = linalgOp.getNumLoops();
778 LDBG() <<
"need 3+ loops to find a matmul to pack, got " << numLoops
779 <<
" in: " << linalgOp;
781 linalgOp,
"need 3+ loops to find a matmul to pack");
785 int64_t numPackedDims = mnkPackedSizes.size();
787 for (
int64_t i = 0, e = numPackedDims; i < e; ++i)
788 mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
790 for (
int64_t i = 0, e = numPackedDims; i < e; ++i)
791 packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
793 for (
int64_t i = 0, e = numPackedDims; i < e; ++i) {
794 paddedSizesNextMultipleOf[mnkOrder[i]] =
795 mnkPaddedSizesNextMultipleOf.empty() ? 0
796 : mnkPaddedSizesNextMultipleOf[i];
800 FailureOr<ContractionDimensions> maybeDimensions =
802 if (failed(maybeDimensions)) {
803 LDBG() <<
"couldn't infer matmul iterators in: " << linalgOp;
805 "couldn't infer matmul iterators");
813 int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
814 kPos = maybeDimensions->k.back();
815 LDBG() <<
"Start packing generic op greedily with (m@" << mPos <<
", n@"
816 << nPos <<
", k@" << kPos <<
"): " << linalgOp;
819 auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
821 FailureOr<GenericOp> generalizeResult =
823 assert(succeeded(generalizeResult) &&
"unexpected failure generalizing op");
824 genericOp = *generalizeResult;
832 LDBG() <<
"perm: " << llvm::interleaved(permutation);
835 FailureOr<GenericOp> interchangeResult =
837 assert(succeeded(interchangeResult) &&
"unexpected failure interchanging op");
838 genericOp = *interchangeResult;
839 LDBG() <<
"Generalized Op to pack: " << genericOp;
856 cast<LinalgOp>(genericOp.getOperation())
857 .createLoopRanges(rewriter, genericOp.getLoc());
861 LDBG() <<
"paddedSizesNextMultipleOf: "
862 << llvm::interleaved(paddedSizesNextMultipleOf);
863 LDBG() <<
"loopRanges: "
864 << llvm::interleaved(
865 llvm::map_range(loopRanges, [](
Range r) {
return r.
size; }));
868 for (
int64_t i = 0, e = numPackedDims; i < e; ++i) {
869 if (paddedSizesNextMultipleOf[i] == 0) {
870 adjustedPackedSizes.push_back(packedSizes[i]);
877 rewriter, genericOp->getLoc(), d0.
ceilDiv(s0) * s0,
878 {loopRanges[adjustedPackedSizes.size()].size,
879 rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
881 LDBG() <<
"adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
887 return pack(rewriter, genericOp, adjustedPackedSizes);
900 b.setInsertionPointToStart(
901 &op->getParentOfType<func::FuncOp>().getBody().front());
902 return llvm::map_to_vector<4>(tileSizes, [&](
int64_t s) {
920 auto padValue = padOp.getConstantPaddingValue();
923 if (padValue.getParentBlock() == &padOp.getRegion().front())
925 return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
929 auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
930 padOp.getResultType(), dynSizes);
933 padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
942 if (
auto val = llvm::dyn_cast_if_present<Value>(ofr))
945 rewriter, padOp.getLoc(),
946 cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
950 auto resultType = padOp.getResultType();
954 for (
unsigned dim = 0; dim < resultType.getRank(); ++dim) {
955 if (resultType.isDynamicDim(dim)) {
957 padOp.getSource(), dim));
960 padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
962 padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
963 dynSizes.push_back(plusHigh);
965 staticSizes.push_back(resultType.getDimSize(dim));
970 tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
971 resultType.getElementType(), dynSizes);
975 auto sourceType = padOp.getSourceType();
983 padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
991 if (!sliceOp.hasUnitStride())
994 auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
998 bool zeroSliceGuard =
true;
1000 if (std::optional<bool> control = controlFn(sliceOp))
1001 zeroSliceGuard = *control;
1006 FailureOr<TilingResult> tilingResult =
1008 sliceOp.getMixedSizes(), zeroSliceGuard);
1009 if (failed(tilingResult))
1012 RankedTensorType sourceType = sliceOp.getSourceType();
1013 RankedTensorType resultType = sliceOp.getResultType();
1017 if (sourceType.getRank() == resultType.getRank()) {
1018 rewriter.
replaceOp(sliceOp, tilingResult->tiledValues);
1024 rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1026 rewriter.
replaceOp(sliceOp, rankReduced);
1036 linalg::PackOp packOp) {
1037 Value input = packOp.getSource();
1039 if (!packOp.hasPureTensorSemantics())
1042 if (!packOp.getPaddingValue()) {
1046 assert(llvm::all_of(packOp.getAllOuterDims(),
1047 [](
int64_t val) { return val == 1; }) &&
1048 "some outer dims are != 1");
1051 ShapedType inputType = packOp.getSourceType();
1052 int64_t inputRank = inputType.getRank();
1055 packOp.getDimAndTileMapping();
1062 for (
int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1065 if (!tileAndPosMapping.count(dimIdx)) {
1066 int64_t inputDimSize = inputType.getDimSize(dimIdx);
1067 assert(inputDimSize == 1 &&
1068 "with all outer dims == 1, this non-tiled input dim should be 1!");
1069 paddedShape.push_back(inputDimSize);
1076 OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1080 if (cstTileSize.has_value()) {
1081 paddedShape.push_back(cstTileSize.value());
1086 paddedShape.push_back(ShapedType::kDynamic);
1089 dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1092 RankedTensorType::get(paddedShape, inputType.getElementType());
1094 false, loc, builder,
1102static SmallVector<int64_t>
1104 constexpr int64_t kNonTiledMarker = -1;
1106 for (
auto [
index, value] : llvm::enumerate(perm))
1109 vec, [&](
int64_t v) {
return v != kNonTiledMarker; });
1116static SmallVector<int64_t>
1125 for (
auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1126 if (llvm::is_contained(innerDimsPos, i)) {
1127 innerDims.push_back(dim++);
1132 outerDims.push_back(dim++);
1133 if (!outerDimsPerm.empty())
1134 rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1145 rankReducedOuterDimsPerm =
1147 if (!rankReducedOuterDimsPerm.empty())
1151 perm.append(innerDims);
1159 if (!packOp.hasPureTensorSemantics())
1162 if (llvm::any_of(packOp.getTiledOuterDims(),
1163 [](
int64_t dim) { return dim != 1; })) {
1165 packOp,
"not all outer dimensions of the result are 1s");
1169 auto outerDimsPerm = packOp.getOuterDimsPerm();
1175 if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](
int64_t dim) {
1176 static int prev = 0;
1178 if (llvm::is_contained(innerDimsPos, dim))
1183 if (dim < prev && (packOp.getResult().getType().getShape()[prev] != 1 ||
1184 packOp.getResult().getType().getShape()[dim] != 1))
1191 packOp,
"At least one non-unit and un-tiled outer dim is permuted, "
1192 "this is not supported ATM!");
1197 int64_t srcRank = packOp.getSourceRank();
1216 for (
int64_t i = 0; i < srcRank; i++) {
1224 if (llvm::is_contained(innerDimsPos, i))
1226 srcPermForTranspose.push_back(i);
1228 srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1232 ShapedType inputTy = cast<ShapedType>(input.
getType());
1234 for (
int64_t i = 0; i < srcRank; i++) {
1235 if (llvm::is_contained(innerDimsPos, i)) {
1239 if (inputTy.isStaticDim(i))
1240 shapeForEmptyOp.push_back(rewriter.
getIndexAttr(inputTy.getShape()[i]));
1242 shapeForEmptyOp.emplace_back(
1243 tensor::DimOp::create(rewriter, loc, input, i).getResult());
1245 shapeForEmptyOp.append(packOp.getMixedTiles());
1252 llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1254 if (auto val = llvm::dyn_cast<Value>(ofr))
1255 return getAsOpFoldResult(val);
1259 LDBG() <<
"Pack permutation: " << packOp;
1260 LDBG() <<
"perm: " << llvm::interleaved(srcPermForTranspose);
1261 LDBG() <<
"Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
1263 Value empty = tensor::EmptyOp::create(
1264 rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
1267 auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
1268 srcPermForTranspose);
1280 for (
auto size : packOp.getAllOuterDims()) {
1284 for (
auto tileSize : packOp.getMixedTiles()) {
1285 auto [_, tileSizeOfr] =
1287 writeSizes.push_back(tileSizeOfr);
1290 auto insert = tensor::InsertSliceOp::create(
1291 rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
1294 rewriter.
replaceOp(packOp, insert.getResult());
1301 if (!unpackOp.hasPureTensorSemantics())
1304 int64_t destRank = unpackOp.getDestRank();
1307 if (llvm::any_of(unpackOp.getTiledOuterDims(),
1308 [](
int64_t dim) { return dim != 1; })) {
1311 "require the tiled outer dimensions of the result are all 1s");
1317 Value source = unpackOp.getSource();
1319 unpackOp.getDimAndTileMapping();
1338 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1347 if (dimAndTileMapping.count(i)) {
1348 extractSliceSizes.push_back(oneIdxAttr);
1354 if (ShapedType::isDynamic(srcShape[i])) {
1356 tensor::DimOp::create(rewriter, loc, source, i).getResult();
1357 extractSliceSizes.push_back(dynamicDim);
1358 shapeForEmptyOp.push_back(dynamicDim);
1360 extractSliceSizes.push_back(rewriter.
getIndexAttr(srcShape[i]));
1361 if (srcShape[i] != 1)
1362 shapeForEmptyOp.push_back(rewriter.
getIndexAttr(srcShape[i]));
1366 if (srcShape[i] != 1) {
1367 readShapeForExtractSlice.push_back(srcShape[i]);
1372 auto mixedTiles = unpackOp.getMixedTiles();
1373 extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1374 shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1378 auto tileShape = srcShape.drop_front(destRank);
1380 readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1381 Type elemType = unpackOp.getSourceType().getElementType();
1382 auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1383 Value innerTile = tensor::ExtractSliceOp::create(
1384 rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);
1388 srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1394 tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
1396 linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
1402 for (
auto i : llvm::seq<unsigned>(0, destRank)) {
1403 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1404 tileSizes.push_back(
1409 tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
1410 transposedOp.getResult()[0], tileSizes);
1414 for (
int i = 0, idx = 0; i < destRank; ++i) {
1415 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1416 writeSizes.push_back(tileSizes[idx++]);
1418 writeSizes.push_back(oneIdxAttr);
1420 auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1421 unpackOp.getDest(), writeSizes);
1422 rewriter.
replaceOp(unpackOp, insert.getResult());
1435template <
typename Conv2DOp,
typename Conv1DOp>
1439 std::optional<DilationsAndStrides> convParams =
1446 if (convOp.hasPureBufferSemantics())
1449 Value input = convOp.getDpsInputs().front();
1450 Value kernel = convOp.getDpsInputs().back();
1451 Value output = convOp.getDpsInits().front();
1453 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1454 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1455 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1457 auto kernelShape = kernelType.getShape();
1458 auto outputShape = outputType.getShape();
1461 int64_t khIndex, kwIndex, ohIndex, owIndex;
1462 if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
1463 std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
1464 std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
1465 std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
1466 std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
1467 std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
1473 }
else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
1479 }
else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
1480 std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {
1490 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1491 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1492 bool removeH = (khSize == 1 && ohSize == 1);
1493 bool removeW = (kwSize == 1 && owSize == 1);
1494 if (!removeH && !removeW)
1500 RankedTensorType newInputType =
1501 RTTBuilder(inputType).
dropDim((removeH ? ohIndex : owIndex));
1502 RankedTensorType newKernelType =
1503 RTTBuilder(kernelType).
dropDim((removeH ? khIndex : kwIndex));
1504 RankedTensorType newOutputType =
1505 RTTBuilder(outputType).
dropDim((removeH ? ohIndex : owIndex));
1510 rewriter, loc, input, newInputType);
1512 rewriter, loc, kernel, newKernelType);
1514 rewriter, loc, output, newOutputType);
1518 strides.erase(strides.begin() + (removeH ? 0 : 1));
1521 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1524 auto conv1DOp = Conv1DOp::create(
1525 rewriter, loc, newOutputType,
ValueRange{newInput, newKernel},
1526 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1530 rewriter, loc, conv1DOp.getResult(0), output);
1547 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1551 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1555FailureOr<DepthwiseConv1DNwcWcOp>
1559 std::optional<DilationsAndStrides> convParams =
1566 if (convOp.hasPureBufferSemantics())
1569 Value input = convOp.getDpsInputs().front();
1570 Value kernel = convOp.getDpsInputs().back();
1571 Value output = convOp.getDpsInits().front();
1573 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1574 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1575 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1577 auto kernelShape = kernelType.getShape();
1578 auto outputShape = outputType.getShape();
1582 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1583 int64_t ohSize = outputShape[1], owSize = outputShape[2];
1584 bool removeH = (khSize == 1 && ohSize == 1);
1585 bool removeW = (kwSize == 1 && owSize == 1);
1586 if (!removeH && !removeW)
1592 RankedTensorType newInputType =
1593 RTTBuilder(inputType).
dropDim((removeH ? 1 : 2));
1594 RankedTensorType newKernelType =
1595 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1596 RankedTensorType newOutputType =
1597 RTTBuilder(outputType).
dropDim(removeH ? 1 : 2);
1602 rewriter, loc, input, newInputType);
1604 rewriter, loc, kernel, newKernelType);
1606 rewriter, loc, output, newOutputType);
1610 strides.erase(strides.begin() + (removeH ? 0 : 1));
1613 dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1616 auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
1617 rewriter, loc, newOutputType,
ValueRange{newInput, newKernel},
1618 ValueRange{newOutput}, stridesAttr, dilationsAttr);
1622 rewriter, loc, conv1DOp.getResult(0), output);
1632 std::optional<DilationsAndStrides> convParams =
1637 if (convOp.hasPureBufferSemantics())
1640 Value input = convOp.getDpsInputs().front();
1641 Value kernel = convOp.getDpsInputs().back();
1642 Value output = convOp.getDpsInits().front();
1644 auto inputType = dyn_cast<RankedTensorType>(input.
getType());
1645 auto kernelType = dyn_cast<RankedTensorType>(kernel.
getType());
1646 auto outputType = dyn_cast<RankedTensorType>(output.
getType());
1648 auto kernelShape = kernelType.getShape();
1649 auto outputShape = outputType.getShape();
1653 int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1654 int64_t ohSize = outputShape[0], owSize = outputShape[1];
1655 bool removeH = (khSize == 1 && ohSize == 1);
1656 bool removeW = (kwSize == 1 && owSize == 1);
1657 if (!removeH && !removeW)
1663 RankedTensorType newInputType =
1664 RTTBuilder(inputType).
dropDim((removeH ? 0 : 1));
1665 RankedTensorType newKernelType =
1666 RTTBuilder(kernelType).
dropDim((removeH ? 0 : 1));
1667 RankedTensorType newOutputType =
1668 RTTBuilder(outputType).
dropDim(removeH ? 0 : 1);
1673 rewriter, loc, input, newInputType);
1675 rewriter, loc, kernel, newKernelType);
1677 rewriter, loc, output, newOutputType);
1680 Conv1DOp::create(rewriter, loc, newOutputType,
1685 rewriter, loc, conv1DOp.getResult(0), output);
1704 PoolingNwcMaxUnsignedOp>,
1707 PoolingNwcMinUnsignedOp>,
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
std::optional< DilationsAndStrides > matchConvolutionOpOfType(LinalgOp op)
Given a linalg op this function returns DilationsAndStrides if it is a convolution op of type ConvOpT...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice + copy.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and applies affine_min/max bounds simplification on the fly where relevant.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
ArrayRef< int64_t > ReassociationIndicesRef
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
FailureOr< Conv1DOp > returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
FailureOr< Conv1DOp > returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
TileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.