#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-transforms"
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
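// Usage sketch (illustrative, not from this file): peel the loops produced by
// a tiling driver. `tiledLoops` is an assumed SmallVector<scf::ForOp>.
//
//   IRRewriter rewriter(funcOp->getContext());
//   // Peels each loop and simplifies affine.min/max bounds on the fly.
//   linalg::peelLoops(rewriter, tiledLoops);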
    if (!e.isFunctionOfDim(dim))
// ...
  return llvm::interleaved(ri, ", ", "|", "");
static FailureOr<SmallVector<std::optional<int64_t>>>
// ... (helper that packs dimension `dim` once in the Linalg metadata)
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    // ... (a second assert checks "num results invariant violation")

    // Leave untouched the maps that do not index the packed dimension.
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // Fail when the packed dimension is not indexed by a plain AffineDimExpr.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();
    // ...
    newMaps.push_back(map);
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
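// Worked example (schematic): packing the reduction dimension d2 of a matmul.
// Before, with iterators [par, par, red]:
//   A: (d0, d1, d2) -> (d0, d2)
//   B: (d0, d1, d2) -> (d2, d1)
//   C: (d0, d1, d2) -> (d0, d1)
// After packing dim = 2, a new trailing iterator d3 (also a reduction) is
// appended and every map that indexes d2 gains d3 as its last result:
//   A: (d0, d1, d2, d3) -> (d0, d2, d3)
//   B: (d0, d1, d2, d3) -> (d2, d1, d3)
//   C: (d0, d1, d2, d3) -> (d0, d1)
// The returned vector records, per operand, which result position was a
// function of d2 (here: 1 for A, 0 for B, std::nullopt for C).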
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes that correspond to operand @ `operandPos`.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};
FailureOr<LowerPackResult>
linalg::lowerPack(RewriterBase &rewriter, linalg::PackOp packOp,
                  bool lowerPadLikeWithInsertSlice) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getPackInverseDestPerm(packOp, packingMetadata);

  // 3. Compute the stripMinedShape: the packed shape before any outer or
  // inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    // ... (`origSize`, `outerSize` and the affine map `map` are computed here
    //      so that highs[pos] = outerSize * innerSize - origSize)
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = arith::ConstantOp::create(
        rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
                            highs, paddingValue, /*nofold=*/false);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "outerPositions: "
         << llvm::interleaved(packingMetadata.outerPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsed;

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // This pack is just a plain pad: insert the padded result into the
    // higher-ranked destination with zero offsets and unit strides.
    SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                    rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                   rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, packOp.getDest());
    auto insertSliceOp = tensor::InsertSliceOp::create(
        rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
        /*offsets=*/zeros, sizes, /*strides=*/ones);

    LDBG() << "insert_slice op: " << insertSliceOp;

    rewriter.replaceOp(packOp, insertSliceOp->getResults());

    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                           /*transposeOp=*/nullptr};
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = tensor::ExpandShapeOp::create(
      rewriter, loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = linalg::TransposeOp::create(
      rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LDBG() << "reshape op: " << reshapeOp;
  LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
  LDBG() << "transpose op: " << transposeOp;

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
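// Schematic effect of lowerPack (not verbatim IR; static 2-D example):
//
//   %0 = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<128x256xf32> -> tensor<16x16x8x16xf32>
//
// becomes, roughly:
//
//   %padded = tensor.pad %src ...          // no-op here: shapes divide evenly
//   %e = tensor.expand_shape %padded ...   // tensor<16x8x16x16xf32>
//   %t = linalg.transpose ins(%e) outs(%dest) permutation = [0, 2, 1, 3]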
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  // ...
  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad: extract the slice from the
    // higher-ranked source with zero offsets and unit strides.
    auto extractSliceOp = tensor::ExtractSliceOp::create(
        rewriter, loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero),
        /*sizes=*/tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle the packed shape into the
  // shape before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: the packed shape without outer and inner
  // permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get dynamic dims from the source, permuted into strip-mined order.
  SmallVector<OpFoldResult> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
                                         stripMinedTensorType.getElementType());
  auto transposeOp =
      linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
                                  packedToStripMinedShapePerm);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsedType;

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = tensor::CollapseShapeOp::create(
      rewriter, loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = tensor::ExtractSliceOp::create(
      rewriter, loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = linalg::CopyOp::create(
      rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
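// Schematic effect of lowerUnPack (not verbatim IR; static example):
//
//   %0 = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<16x16x8x16xf32> -> tensor<128x256xf32>
//
// becomes, roughly: tensor.empty + linalg.transpose (back to the strip-mined
// layout) + tensor.collapse_shape + tensor.extract_slice + linalg.copy into
// %dest.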
SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}
/// Implement packing of a single LinalgOp by `packedSizes`.
/// There must be one packedSizes entry per `linalgOp` iterator.
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LDBG() << "Start packing: " << linalgOp;
  LDBG() << "maps: " << llvm::interleaved(indexingMaps);
  LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);

  SmallVector<linalg::PackOp> packOps;
  SmallVector<linalg::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;

    LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
    LDBG() << "maps: " << llvm::interleaved(indexingMaps);
    LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
    LDBG() << "packedDimForEachOperand: "
           << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);

    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands =
      llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LDBG() << "operand: " << operand;
      LDBG() << "innerPos: " << llvm::interleaved(innerPos);
      LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = linalg::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !linalg::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
                                                 innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
        packOps.push_back(linalg::PackOp::create(
            rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp =
      linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
                                inputs, inits, indexingMaps, iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    linalg::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<linalg::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(linalg::UnPackOp::create(
        rewriter, packedLinalgOp->getLoc(), result,
        maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(),
        maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
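// Usage sketch (assumed driver code, e.g. from a transform implementation):
// pack a matmul's (m, n, k) loops by (8, 16, 32). A zero size means "do not
// pack this loop".
//
//   SmallVector<OpFoldResult> packedSizes = {rewriter.getIndexAttr(8),
//                                            rewriter.getIndexAttr(16),
//                                            rewriter.getIndexAttr(32)};
//   FailureOr<PackResult> res = linalg::pack(rewriter, matmulOp, packedSizes);
//   if (failed(res))
//     return failure(); // e.g. a non-AffineDimExpr access prevented packing.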
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type; the type is derived
  // by permuting:
  //     cast<RankedTensorType>(opOperand.get().getType()), permutation);
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map; getPermutationMap expects unsigned
  // positions.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = linalg::GenericOp::create(
      rewriter, linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
                      linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  linalg::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // Leading part just reuse the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // Trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  // ...
  // Step 2.b. Actually perform the transposition; save the operand number of
  // packUse first, since `linalgOp` is about to be replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  linalg::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
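// Usage sketch (assumed caller, hypothetical permutations): transpose an
// existing pack -> compute -> unpack chain so the packed operand uses an
// inner permutation of [1, 0].
//
//   FailureOr<PackTransposeResult> res = linalg::packTranspose(
//       rewriter, packOp, packedMatmul, unpackOp,
//       /*outerPerm=*/{}, /*innerPerm=*/{1, 0});
//   if (failed(res))
//     return failure(); // e.g. the pack result had more than one use.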
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
           << " in: " << linalgOp;
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to a matmul-like op with the most minor iterators.
  // In cases with multiple options for m, n, k, bias towards the most minor
  // embedding.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
         << nPos << ", k@" << kPos << "): " << linalgOp;

  // 2.a. Rewrite as a generic op, if needed.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the (m, n, k) dimensions to the most minor
  // iterator positions.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LDBG() << "perm: " << llvm::interleaved(permutation);
  SmallVector<unsigned> interchangeVector(permutation.begin(),
                                          permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, interchangeVector);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LDBG() << "Generalized Op to pack: " << genericOp;

  // 3. Compute the packed sizes, padding each to the requested next multiple
  // where one is specified; leading non-packed loops get a 0 size.
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  SmallVector<Range> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());
  LDBG() << "paddedSizesNextMultipleOf: "
         << llvm::interleaved(paddedSizesNextMultipleOf);
  LDBG() << "loopRanges: "
         << llvm::interleaved(
                llvm::map_range(loopRanges, [](Range r) { return r.size; }));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);

  // 4. Delegate to the one-shot packing entry point.
  return pack(rewriter, genericOp, adjustedPackedSizes);
}
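// Usage sketch (assumed driver): pack m, n, k by 8, 16, 32, padding k to the
// next multiple of 64, with the identity (m, n, k) order.
//
//   FailureOr<PackResult> res = linalg::packMatmulGreedily(
//       rewriter, linalgOp,
//       /*mnkPackedSizes=*/
//       {rewriter.getIndexAttr(8), rewriter.getIndexAttr(16),
//        rewriter.getIndexAttr(32)},
//       /*mnkPaddedSizesNextMultipleOf=*/{0, 0, 64},
//       /*mnkOrder=*/{0, 1, 2});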
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
      return v;
    }));
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue) {
    // Move the padding value defined inside the PadOp block to outside.
    if (padValue.getParentBlock() == &padOp.getRegion().front())
      rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
    return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
  }

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
                                               padOp.getResultType(), dynSizes);
  // Copy region to new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return arith::ConstantIndexOp::create(
               rewriter, padOp.getLoc(),
               cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor =
      tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
                              resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate an InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of the source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  // Strides of the InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}
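// Usage sketch (assumed pass setup): this decomposition is normally installed
// through the populate API.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposePadPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();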
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();

  // If the extract_slice is not rank-reducing, replace it directly.
  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
    return success();
  }

  // Handle rank-reducing slices.
  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
  rewriter.replaceOp(sliceOp, rankReduced);
  return success();
}
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           linalg::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles.
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims: preserve them as-is (they should be 1).
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims: as all outer dims == 1, it is safe to use the tile
    // size for the padded shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
  return normalizedPerm;
}

static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;
  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the innermost dims after packing.
  perm.append(innerDims);

  return perm;
}
  if (llvm::any_of(packOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  auto outerDimsPerm = packOp.getOuterDimsPerm();
  // Verify that no non-unit, un-tiled outer dims are permuted. Supporting such
  // cases would require refining the transpose-permutation logic below.
  if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
        static int prev = 0;
        // Skip tiled dims - these can be permuted.
        if (llvm::is_contained(innerDimsPos, dim))
          return true;
        // Check whether this dim has been permuted. Permuting unit dims is
        // fine as that's effectively a no-op.
        if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
                           packOp.getType().getShape()[dim] != 1))
          return false;
        prev = dim;
        return true;
      })) {
    return rewriter.notifyMatchFailure(
        packOp, "At least one non-unit and un-tiled outer dim is permuted, "
                "this is not supported ATM!");
  }

  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  Location loc = packOp.getLoc();

  int64_t srcRank = packOp.getSourceRank();
  int64_t destRank = packOp.getDestRank();

  // 1. Get the input that is going to be packed. If the input requires
  // padding, add a padding operation and return the padded input.
  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);

  // 2. Compute the permutation for the transpose: non-tiled dims first, the
  // tiled dims (in inner_dims_pos order) last.
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < srcRank; i++) {
    if (llvm::is_contained(innerDimsPos, i))
      continue;
    srcPermForTranspose.push_back(i);
  }
  srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

  // 3. Compute the shape of the empty tensor to transpose into.
  ShapedType inputTy = cast<ShapedType>(input.getType());
  SmallVector<OpFoldResult> shapeForEmptyOp;
  for (int64_t i = 0; i < srcRank; i++) {
    if (llvm::is_contained(innerDimsPos, i)) {
      // The tiled dims are appended after this loop.
      continue;
    }
    if (inputTy.isStaticDim(i))
      shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
    else
      shapeForEmptyOp.emplace_back(
          tensor::DimOp::create(rewriter, loc, input, i).getResult());
  }
  shapeForEmptyOp.append(packOp.getMixedTiles());

  // getMixedTiles() may contain Values pointing to constant ops, not the
  // attributes themselves - replace them with the actual attributes.
  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
                  [](OpFoldResult ofr) {
                    if (auto val = llvm::dyn_cast<Value>(ofr))
                      return getAsOpFoldResult(val);
                    return ofr;
                  });

  LDBG() << "Pack permutation: " << packOp;
  LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);

  Value empty = tensor::EmptyOp::create(
      rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());

  // 4. Transpose the input into the empty tensor.
  auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                  srcPermForTranspose);

  // 5. Insert the result of the transpose into the pack op's destination:
  // zero offsets, unit strides, sizes = outer dims followed by tile sizes.
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  SmallVector<OpFoldResult> writeSizes;
  for (auto size : packOp.getAllOuterDims()) {
    writeSizes.push_back(rewriter.getIndexAttr(size));
  }
  for (auto tileSize : packOp.getMixedTiles()) {
    auto [_, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
  }

  auto insert = tensor::InsertSliceOp::create(
      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
      writeOffsets, writeSizes, writeStrides);

  // 6. Replace packOp by the insert_slice.
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}
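// Schematic effect (not verbatim IR): a pack whose outer dims are all 1, e.g.
//
//   %0 = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [5, 1]
//       into %dest : tensor<5x1xf32> -> tensor<1x1x5x1xf32>
//
// decomposes into a linalg.transpose of the (possibly padded) source followed
// by
//
//   tensor.insert_slice %t into %dest[0, 0, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1]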
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use a rank-reduced tensor.extract_slice to extract the tile:
  //    tensor<1x1x..x1 x tile_1 x .. x tile_k> -> tensor<tile_1 x .. x tile_k>
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);

  // Shape for the rank-reduced ExtractSliceOp result:
  //   [ outer-untiled-dims, tile-sizes ]
  SmallVector<int64_t> readShapeForExtractSlice;
  // Sizes attribute for ExtractSliceOp (tiled outer dims are all 1):
  //   [ outer-untiled-dims, 1s, tile-sizes ]
  SmallVector<OpFoldResult> extractSliceSizes;
  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
  // Shape for the EmptyOp used as the init of the TransposeOp below; unit
  // dims are skipped since the transpose uses a rank-reduced permutation.
  SmallVector<OpFoldResult> shapeForEmptyOp;

  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Tiled outer dims contribute a unit size.
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
      continue;
    }

    // Compute sizes attribute for ExtractSliceOp + EmptyOp.
    if (ShapedType::isDynamic(srcShape[i])) {
      OpFoldResult dynamicDim =
          tensor::DimOp::create(rewriter, loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
    } else {
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    // Compute the rank-reduced shape for ExtractSliceOp.
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);
    }
  }
  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  // Append the tile sizes to the rank-reduced shape.
  auto tileShape = srcShape.drop_front(destRank);
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = tensor::ExtractSliceOp::create(
      rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
      extractSliceSizes, extractSliceStrides);

  // 2. Transpose the tile to match the outer corresponding dims order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  perm = invertPermutationVector(perm);
  Value empty =
      tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);

  // 3. Handle incomplete tiles by truncating trailing data from the
  // transposed tile.
  int numLoops = shapeForEmptyOp.size();
  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
  SmallVector<OpFoldResult> tileSizes;
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }
  Value partialTile =
      tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
                                     tileOffsets, tileSizes, tileStrides);

  // 4. Insert the result into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
                                              unpackOp.getDest(), writeOffsets,
                                              writeSizes, writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}
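// Usage sketch (assumed pass setup): both decompositions above are normally
// applied through the populate API rather than invoked directly.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposePackUnpackPatterns(patterns);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));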
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Get domain indices based on the conv2D layout.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .DefaultUnreachable("unexpected conv2d/pool2d operation.");

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = Conv1DOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

// ...
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
// ...
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
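// Schematic effect (not verbatim IR): with kh == 1 and oh == 1, e.g.
//   linalg.conv_2d_nhwc_hwcf : (tensor<1x1x10x3xf32>, tensor<1x2x3x4xf32>)
// is rank-reduced to a linalg.conv_1d_nwc_wcf on tensor<1x10x3xf32> /
// tensor<2x3x4xf32>: the unit H dimension is dropped from the operand types
// and from strides/dilations, and the 1-D result is re-inserted into the
// original output via a rank-extending insert_slice.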
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp =
      Conv1DOp::create(rewriter, loc, newOutputType,
                       ValueRange{newInput, newKernel}, ValueRange{newOutput});

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<
      // ... (Conv2D and Pooling downscaling patterns)
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                            PoolingNwcMaxUnsignedOp>,
      // ...
      DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                            PoolingNwcMinUnsignedOp>
      // ...
      >(patterns.getContext(), benefit);
}
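// Usage sketch (assumed pass setup): these downscaling rewrites are normally
// installed through the populate API.
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns, /*benefit=*/1);
//   (void)applyPatternsGreedily(funcOp, std::move(patterns));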