#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-transforms"
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}
    if (!e.isFunctionOfDim(dim))
static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
  return llvm::interleaved(ri, ", ", "|", "");
}
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // The map must range over all the (old) iteration dimensions.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    // ...
    assert(/*...*/ "num results invariant violation");

    // Operands that are not indexed by `dim` keep their map unchanged.
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // Fail when the packed dimension is not a simple AffineDimExpr.
    if (!isa<AffineDimExpr>(
            map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // ... (insert the new packed dimension into the map)
    newMaps.push_back(map);
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes for operand @ `operandPos`.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};
FailureOr<LowerPackResult>
linalg::lowerPack(RewriterBase &rewriter, linalg::PackOp packOp,
                  bool lowerPadLikeWithInsertSlice) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();

  // 2. Compute the permutation vector to shuffle the packed shape into the
  // shape before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getPackInverseDestPerm(packOp, packingMetadata);

  // 3. Compute the stripMinedShape: the packed shape before any outer or
  // inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = arith::ConstantOp::create(
        rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
                            highs, paddingValue, /*nofold=*/false);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "outerPositions: "
         << llvm::interleaved(packingMetadata.outerPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsed;

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack-as-pad: rewrite directly to an insert_slice of the padded source.
    auto insertSliceOp = tensor::InsertSliceOp::create(
        rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
        /*...*/);
    LDBG() << "insert_slice op: " << insertSliceOp;
    rewriter.replaceOp(packOp, insertSliceOp->getResults());
    return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                           /*transposeOp=*/nullptr};
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = tensor::ExpandShapeOp::create(
      rewriter, loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = linalg::TransposeOp::create(
      rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LDBG() << "reshape op: " << reshapeOp;
  LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
  LDBG() << "transpose op: " << transposeOp;

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());
  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
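A minimal usage sketch for the entry point above. The helper name is hypothetical; it assumes the standard MLIR walk/rewriter APIs and the lowerPack declaration from Transforms.h:

// Hypothetical driver: lower every linalg.pack in a function to
// pad + expand_shape + transpose via linalg::lowerPack.
static LogicalResult lowerAllPackOps(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  WalkResult status = funcOp.walk([&](linalg::PackOp packOp) {
    rewriter.setInsertionPoint(packOp);
    if (failed(linalg::lowerPack(rewriter, packOp,
                                 /*lowerPadLikeWithInsertSlice=*/true)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return failure(status.wasInterrupted());
}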
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  // ...
  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  // ...
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // Unpad-like unpack: rewrite directly to an extract_slice of the source.
    auto extractSliceOp = tensor::ExtractSliceOp::create(
        rewriter, loc, destTensorType, unPackOp.getSource(),
        /*...*/);
    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle the packed shape into the
  // shape before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: the packed shape without outer and inner
  // permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // 3. Transpose packedShape to stripMinedShape.
  SmallVector<OpFoldResult> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
                                         stripMinedTensorType.getElementType());
  auto transposeOp =
      linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
                                  packedToStripMinedShapePerm);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsedType;

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = tensor::CollapseShapeOp::create(
      rewriter, loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = tensor::ExtractSliceOp::create(
      rewriter, loc, destTensorType, reshapeOp->getResult(0),
      /*...*/);

  // 6. Insert a copy to preserve DPS semantics.
  auto copyOp = linalg::CopyOp::create(
      rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());
  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
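A hedged companion sketch for the unpack side; the caller (e.g. a transform-dialect handler) is assumed to provide `rewriter` and `unPackOp`:

// Hypothetical handler: lower a single linalg.unpack to
// empty + transpose + collapse_shape + extract_slice (+ copy).
static LogicalResult lowerOneUnPack(RewriterBase &rewriter,
                                    linalg::UnPackOp unPackOp) {
  rewriter.setInsertionPoint(unPackOp);
  FailureOr<linalg::LowerUnPackOpResult> res = linalg::lowerUnPack(
      rewriter, unPackOp, /*lowerUnpadLikeWithExtractSlice=*/true);
  return success(succeeded(res));
}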
SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LDBG() << "Start packing: " << linalgOp;
  LDBG() << "maps: " << llvm::interleaved(indexingMaps);
  LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);

  SmallVector<linalg::PackOp> packOps;
  SmallVector<linalg::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;

    LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
    LDBG() << "maps: " << llvm::interleaved(indexingMaps);
    LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
    LDBG() << "packedDimForEachOperand: "
           << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);

    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  SmallVector<OpOperand *> initOperands =
      llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LDBG() << "operand: " << operand;
      LDBG() << "innerPos: " << llvm::interleaved(innerPos);
      LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = linalg::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !linalg::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
                                                 innerPos, innerPackSizes));
      } else {
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
        packOps.push_back(linalg::PackOp::create(
            rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp =
      linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
                                inputs, inits, indexingMaps, iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    linalg::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<linalg::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(linalg::UnPackOp::create(
        rewriter, packedLinalgOp->getLoc(), result,
        maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(),
        maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
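A sketch of calling pack() directly; the helper name and tile sizes are illustrative, and the op is assumed to have exactly three loops (one packed size is required per loop):

// Hypothetical: pack a 3-loop LinalgOp with constant (m, n, k) = (8, 16, 32)
// tiles; a tile size of 0 would leave the corresponding loop unpacked.
static FailureOr<linalg::PackResult>
packWithFixedTiles(RewriterBase &rewriter, linalg::LinalgOp matmulOp) {
  Builder b(matmulOp.getContext());
  SmallVector<OpFoldResult> packedSizes = {
      b.getIndexAttr(8), b.getIndexAttr(16), b.getIndexAttr(32)};
  rewriter.setInsertionPoint(matmulOp);
  return linalg::pack(rewriter, matmulOp, packedSizes);
}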
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = /*...*/(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map.
  // AffineMap::getPermutationMap requires unsigned indices.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // ... (`operandsRef` spans the operands with `transposedValue` substituted)

  // Create the transposed generic op and steal the original region.
  auto transposedGenericOp = linalg::GenericOp::create(
      rewriter, linalgOp.getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      indexingMaps,
      linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
                      linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  linalg::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // Add the trailing dimensions, transposed by `innerPerm` if present.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }

  // Step 2.b. Save the operand number: `linalgOp` is about to be replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  linalg::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
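A hedged sketch of transposing an existing pack -> linalg -> unpack chain; the permutation is illustrative and must match the ranks described above:

// Hypothetical: swap the two leading (outer) dims of the packed operand and
// leave the inner tile order unchanged.
static LogicalResult transposePackedChain(RewriterBase &rewriter,
                                          linalg::PackOp packOp,
                                          linalg::LinalgOp linalgOp,
                                          linalg::UnPackOp unPackOp) {
  SmallVector<int64_t> outerPerm = {1, 0}; // assumes a rank-2 outer shape
  FailureOr<linalg::PackTransposeResult> res = linalg::packTranspose(
      rewriter, packOp, linalgOp, unPackOp, outerPerm, /*innerPerm=*/{});
  return success(succeeded(res));
}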
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
           << " in: " << linalgOp;
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize to a matmul-like form with the (m, n, k) candidates taken
  // from the innermost inferred dimensions.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
         << nPos << ", k@" << kPos << "): " << linalgOp;

  // 3. Generalize named ops to linalg.generic so the body can be reused.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 4. Interchange to move the (m, n, k) dims to the desired positions.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LDBG() << "perm: " << llvm::interleaved(permutation);
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LDBG() << "Generalized Op to pack: " << genericOp;

  // 5. Compute the packed sizes, padding each to the requested next multiple
  // when one is provided.
  SmallVector<Range> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());
  LDBG() << "paddedSizesNextMultipleOf: "
         << llvm::interleaved(paddedSizesNextMultipleOf);
  LDBG() << "loopRanges: "
         << llvm::interleaved(
                llvm::map_range(loopRanges, [](Range r) { return r.size; }));

  SmallVector<OpFoldResult> adjustedPackedSizes;
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);

  // Pack the generic op with the adjusted sizes.
  return pack(rewriter, genericOp, adjustedPackedSizes);
}
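A hedged driver sketch for the greedy matmul packing, with illustrative tile sizes and no next-multiple padding:

// Hypothetical: pack any contraction-like op with (m, n, k) = (16, 32, 64).
static FailureOr<linalg::PackResult>
packGreedily16x32x64(RewriterBase &rewriter, linalg::LinalgOp linalgOp) {
  Builder b(linalgOp.getContext());
  SmallVector<OpFoldResult> mnkPackedSizes = {
      b.getIndexAttr(16), b.getIndexAttr(32), b.getIndexAttr(64)};
  return linalg::packMatmulGreedily(rewriter, linalgOp, mnkPackedSizes,
                                    /*mnkPaddedSizesNextMultipleOf=*/{},
                                    /*mnkOrder=*/{0, 1, 2});
}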
  // In LinalgTilingOptions::setTileSizes: the tile-size computation callback
  // materializes the constant sizes at the start of the enclosing function.
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
      return v;
    }));
  };
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue) {
    // Move a padding value defined inside the PadOp block to the outside.
    if (padValue.getParentBlock() == &padOp.getRegion().front())
      rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
    return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
  }
  // Fall back to tensor.generate, cloning the pad op's region into it.
  auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
                                               padOp.getResultType(), dynSizes);
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return arith::ConstantIndexOp::create(
               rewriter, padOp.getLoc(),
               cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute the empty tensor's sizes; any static/dynamic mix is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add the low and high padding values.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Create the empty tensor, fill it with padding, then insert the source.
  Value emptyTensor =
      tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
                              resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
  auto sourceType = padOp.getSourceType();
  // ... (compute `srcSizes` and unit `strides`)
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);
  return success();
}
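A hedged registration sketch for the pad decomposition. populateDecomposePadPatterns is the hook declared in Transforms.h; the greedy driver name is the current upstream one and may differ across MLIR versions:

void decomposePads(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  linalg::populateDecomposePadPatterns(patterns);
  // Assumes mlir/Transforms/GreedyPatternRewriteDriver.h; older trees call
  // this applyPatternsAndFoldGreedily.
  (void)applyPatternsGreedily(funcOp, std::move(patterns));
}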
LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();
  // If the rank was not reduced, the tiled values can be used directly.
  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
    return success();
  }
  // Rank-reduced case: insert a rank-reducing extract_slice first.
  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
  rewriter.replaceOp(sliceOp, rankReduced);
  return success();
}
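The pattern above is a thin wrapper; the same swap can be requested directly through bubbleUpPadSlice (signature listed at the bottom of this page, assumed here to live in the tensor namespace):

// Hypothetical: bubble a slice above its producing pad, guarding against
// zero-sized slices.
static FailureOr<TilingResult> swapSliceOfPad(OpBuilder &b,
                                              tensor::ExtractSliceOp sliceOp,
                                              tensor::PadOp padOp) {
  return tensor::bubbleUpPadSlice(b, padOp, sliceOp.getMixedOffsets(),
                                  sliceOp.getMixedSizes(),
                                  /*generateZeroSliceGuard=*/true);
}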
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           linalg::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue())
    return input;

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();
  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  SmallVector<Value> dynamicTileSizes;
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims: these should all be 1, simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims: pad each up to its tile size.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1. Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2. Dynamic tile sizes: mark the dim dynamic and keep the size value.
    paddedShape.push_back(ShapedType::kDynamic);
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
      vec, [&](int64_t v) { return v != kNonTiledMarker; }));
  // ...
}

static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  // ...
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    // ...
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }
  // ...
  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the innermost dims after packing.
  perm.append(innerDims);
  return perm;
}
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    linalg::PackOp packOp, PatternRewriter &rewriter) const {
  // This pattern requires all tiled outer dims to be unit-sized.
  if (llvm::any_of(packOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  auto outerDimsPerm = packOp.getOuterDimsPerm();
  // Permutations of non-unit, un-tiled outer dims are not supported.
  if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
        static int prev = 0;
        // Tiled dims are fine: skip them.
        if (llvm::is_contained(innerDimsPos, dim))
          return true;
        // Reject reordered non-unit dims.
        if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
                           packOp.getType().getShape()[dim] != 1))
          return false;
        prev = dim;
        return true;
      })) {
    return rewriter.notifyMatchFailure(
        packOp, "At least one non-unit and un-tiled outer dim is permuted, "
                "this is not supported ATM!");
  }

  Location loc = packOp.getLoc();
  int64_t srcRank = packOp.getSourceRank();

  // 1. Get the input that is going to be packed; pad it first if required.
  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);

  // 2. Transpose the input to match the inner tile order.
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < srcRank; i++) {
    // Tiled dims are appended at the back below; skip them here.
    if (llvm::is_contained(innerDimsPos, i))
      continue;
    srcPermForTranspose.push_back(i);
  }
  srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

  // 2.1. Create tensor.empty (the init value for the TransposeOp).
  ShapedType inputTy = cast<ShapedType>(input.getType());
  SmallVector<OpFoldResult> shapeForEmptyOp;
  for (int64_t i = 0; i < srcRank; i++) {
    if (llvm::is_contained(innerDimsPos, i)) {
      // Tiled dims are handled via getMixedTiles below.
      continue;
    }
    if (inputTy.isStaticDim(i))
      shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
    else
      shapeForEmptyOp.emplace_back(
          tensor::DimOp::create(rewriter, loc, input, i).getResult());
  }
  shapeForEmptyOp.append(packOp.getMixedTiles());

  // Canonicalize any constant Values in the shape into attributes.
  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
                  [](OpFoldResult ofr) -> OpFoldResult {
                    if (auto val = llvm::dyn_cast<Value>(ofr))
                      return getAsOpFoldResult(val);
                    return ofr;
                  });

  LDBG() << "Pack permutation: " << packOp;
  LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
  Value empty = tensor::EmptyOp::create(
      rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.2. Create linalg.transpose.
  auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                  srcPermForTranspose);

  // 3. Insert the inner tile into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  // Outer dims are all unit-sized.
  for (auto size : packOp.getAllOuterDims()) {
    // ...
  }
  for (auto tileSize : packOp.getMixedTiles()) {
    auto [_, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
  }
  // ...
  auto insert = tensor::InsertSliceOp::create(
      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);

  // 4. Replace the pack op with the insert_slice.
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}
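The pattern is normally registered through populateDecomposePackUnpackPatterns; a hedged sketch of registering it on its own with a custom benefit (class name as declared in Transforms.h):

static void addPackDecomposition(MLIRContext *ctx,
                                 RewritePatternSet &patterns) {
  // Alternative to populateDecomposePackUnpackPatterns: only the pack-side
  // pattern, with a higher benefit so it runs before competing rewrites.
  patterns.add<linalg::DecomposeOuterUnitDimsPackOpPattern>(ctx,
                                                            /*benefit=*/2);
}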
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use rank-reduced tensor.extract_slice to extract the tile.
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  // ...

  SmallVector<OpFoldResult> extractSliceSizes;
  SmallVector<OpFoldResult> shapeForEmptyOp;
  SmallVector<int64_t> readShapeForExtractSlice;
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Sizes for the ExtractSliceOp: outer dims are all 1s, so tiled dims read
    // a unit slice, while untiled dims contribute their full size.
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
      continue;
    }
    // ...
    if (ShapedType::isDynamic(srcShape[i])) {
      Value dynamicDim =
          tensor::DimOp::create(rewriter, loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
    } else {
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    // Unit dims are rank-reduced away from the read shape.
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);
    }
  }
  // The sizes contributed by the inner tiles.
  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  // Append the inner tile shape to the permuted, rank-reduced outer shape.
  auto tileShape = srcShape.drop_front(destRank);
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = tensor::ExtractSliceOp::create(
      rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);

  // 2. Transpose the tile to match the outer ordering.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // ...
  Value empty =
      tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);

  // 3. Handle incomplete tiles: truncate trailing data from the transposed
  // tile if the destination is smaller.
  SmallVector<OpFoldResult> tileSizes;
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
          tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
  }
  auto partialTile =
      tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
                                     transposedOp.getResult()[0], tileSizes);

  // 4. Insert the result into the destination tensor.
  SmallVector<OpFoldResult> writeSizes;
  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
    else
      writeSizes.push_back(oneIdxAttr);
  }
  auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
                                              unpackOp.getDest(), writeSizes);
  rewriter.replaceOp(unpackOp, insert.getResult());

  return success();
}
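Roughly, for a unit-outer-dim unpack such as tensor<1x1x8x16xf32> -> tensor<8x16xf32>, the rewrite above produces IR of the following shape (sizes and permutation illustrative):

// %tile  = tensor.extract_slice %src ...  : ... to tensor<8x16xf32>
// %empty = tensor.empty()                 : tensor<8x16xf32>
// %t     = linalg.transpose ins(%tile) outs(%empty) permutation = [0, 1]
// %out   = tensor.insert_slice %t into %dest ...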
template <typename Conv2DOp, typename Conv1DOp>
FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
    returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Get domain indices based on the conv2d layout.
  auto [khIndex, kwIndex, ohIndex, owIndex] =
      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
          convOp)
          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .DefaultUnreachable("unexpected conv2d/pool2d operation.");

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim(removeH ? ohIndex : owIndex);
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim(removeH ? khIndex : kwIndex);
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? ohIndex : owIndex);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  SmallVector<int64_t, 4> strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  SmallVector<int64_t, 4> dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = Conv1DOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

// ... (explicit instantiations, excerpt)
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
template struct linalg::DownscaleSizeOneWindowed2DConvolution<
    PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
FailureOr<DepthwiseConv1DNwcWcOp>
DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1. Other cases can rely on tiling to reduce to such cases.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types for all operands by removing the size-1
  // dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType = RTTBuilder(inputType).dropDim(removeH ? 1 : 2);
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  // Rank-reduce strides and dilations too.
  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));
  auto stridesAttr = rewriter.getI64VectorAttr(strides);

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);

  auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}
FailureOr<Conv1DOp>
DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
                                            PatternRewriter &rewriter) const {
  if (convOp.hasPureBufferSemantics())
    return failure(); // To be implemented.

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  // Only handle the case where at least one of the window dimensions is of
  // size 1.
  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)
    return failure();

  // Get new shapes and types by removing the size-1 dimension.
  using RTTBuilder = RankedTensorType::Builder;
  RankedTensorType newInputType = RTTBuilder(inputType).dropDim(removeH ? 0 : 1);
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim(removeH ? 0 : 1);
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

  // Rank-reduce operands.
  Location loc = convOp.getLoc();
  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, input, newInputType);
  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, kernel, newKernelType);
  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, loc, output, newOutputType);

  auto conv1DOp =
      Conv1DOp::create(rewriter, loc, newOutputType,
                       ValueRange{newInput, newKernel}, ValueRange{newOutput});

  // Insert back.
  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
      rewriter, loc, conv1DOp.getResult(0), output);
  rewriter.replaceOp(convOp, inserted);

  return conv1DOp;
}

void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
                                                  PatternBenefit benefit) {
  patterns.add<DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
                                                     PoolingNwcMaxUnsignedOp>,
               DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
                                                     PoolingNwcMinUnsignedOp>
               /* ... remaining conv/pool downscale patterns ... */>(
      patterns.getContext(), benefit);
}
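A hedged usage sketch for the conv decompositions registered above; the driver name matches current upstream and may differ in older trees:

void decompose2DConvsTo1D(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  linalg::populateDecomposeConvolutionPatterns(patterns, /*benefit=*/1);
  (void)applyPatternsGreedily(funcOp, std::move(patterns));
}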
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite pack as empty + transpose + reshape + extract_slice.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and applies affine_min/max bounds simplification on the fly where relevant.
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
ArrayRef< int64_t > ReassociationIndicesRef
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Fill dest using a FillOp with the constant padding value, if possible.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
TileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.