#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-transforms"
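// peelLoop: "Try to peel and canonicalize loop op and return the new result"
// (header docs). The scf::ForOp case attempts
// scf::peelForLoopAndSimplifyBounds and falls back to the original loop
// results when peeling does not apply.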
.Case<scf::ForOp>([&](scf::ForOp forOp) {
  scf::ForOp partialIteration;
  if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                  partialIteration)))
    return partialIteration->getResults();
  assert(!partialIteration && "expected that loop was not peeled");
  return forOp->getResults();
})
for (auto loopOp : loops)
  peelLoop(rewriter, loopOp);

    if (!e.isFunctionOfDim(dim))

static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
  return llvm::interleaved(ri, /*Separator=*/", ", /*Prefix=*/"|",
                           /*Suffix=*/"");
}
static FailureOr<SmallVector<std::optional<int64_t>>>

  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

      indexingMaps.size(), std::nullopt);

  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;

    AffineMap map = indexingMaps[operandIdx];

    assert(map.getNumDims() == newDim && "num dims invariant violation");

           "num results invariant violation");

    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);

    if (!isa<AffineDimExpr>(
            map.getResult(maybeOperandDimensionToPack.value())))

    newMaps.push_back(map);

    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;

  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
struct PackedOperandsDim {

struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
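// lowerPack implements the documented rewrite "pack as pad + reshape +
// transpose": pad the source up to a multiple of the inner tiles, expand
// the padded tensor into the strip-mined shape, then transpose into the
// packed layout (with an insert_slice shortcut for pad-like packs). A
// minimal sketch of the idea, with hypothetical shapes:
//
//   %pack = linalg.pack %src padding_value(%cst : f32)
//       inner_dims_pos = [0] inner_tiles = [8]
//       into %dest : tensor<15xf32> -> tensor<2x8xf32>
//
// becomes, roughly:
//
//   %pad    = tensor.pad %src low[0] high[1] {...}
//               : tensor<15xf32> to tensor<16xf32>
//   %expand = tensor.expand_shape %pad [[0, 1]] output_shape [2, 8]
//               : tensor<16xf32> into tensor<2x8xf32>
//   %res    = linalg.transpose ins(%expand ...) outs(%dest ...)
//               permutation = [0, 1]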
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             linalg::PackOp packOp,
                                             bool lowerPadLikeWithInsertSlice) {

  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {

        "non-static shape NYI, needs a more powerful tensor.expand_shape op");

  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());

  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {

        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];

        rewriter, loc, map, {outerSize, origSize, innerSize});

  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);

  Value paddingValue = packOp.getPaddingValue();

    paddingValue = arith::ConstantOp::create(

  auto padOp =
      tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
                            highs, paddingValue, /*nofold=*/false);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "outerPositions: "
         << llvm::interleaved(packingMetadata.outerPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsed;

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {

    auto insertSliceOp = tensor::InsertSliceOp::create(
        rewriter, loc, padOp, packOp.getDest(),

    LDBG() << "insert_slice op: " << insertSliceOp;

    rewriter.replaceOp(packOp, insertSliceOp->getResults());

  auto expandShapeResultType =

  auto reshapeOp = tensor::ExpandShapeOp::create(
      rewriter, loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  auto transposeOp = linalg::TransposeOp::create(
      rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LDBG() << "reshape op: " << reshapeOp;
  LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
  LDBG() << "transpose op: " << transposeOp;

  rewriter.replaceOp(packOp, transposeOp->getResults());
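// lowerUnPack is the inverse lowering, documented as "empty + transpose +
// reshape + extract_slice": transpose the packed source back into the
// strip-mined order, collapse_shape it, then extract_slice the original
// extent (with an extract_slice shortcut for unpad-like unpacks) and copy
// into the destination.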
FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {

    auto extractSliceOp = tensor::ExtractSliceOp::create(
        rewriter, loc, destTensorType, unPackOp.getSource(),

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

        nullptr, extractSliceOp};

  PackingMetadata packingMetadata;

  RankedTensorType stripMinedTensorType =

  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
                                         stripMinedTensorType.getElementType());
  auto transposeOp =
      linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
                                  packedToStripMinedShapePerm);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
                                              stringifyReassocIndices));
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsedType;

  auto reshapeOp = tensor::CollapseShapeOp::create(
      rewriter, loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = tensor::ExtractSliceOp::create(
      rewriter, loc, destTensorType, reshapeOp->getResult(0),

  auto copyOp = linalg::CopyOp::create(
      rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());

  rewriter.replaceOp(unPackOp, copyOp->getResults());
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {

  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())

    res.push_back(i.packedDimForEachOperand[operandPos].value());

PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {

  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())

    res.push_back(i.packedSize);
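// pack() implements "packing of a single LinalgOp by packedSizes" (header
// docs): each non-zero packed size blocks one loop dimension, every operand
// that accesses that dimension is packed, the op is rebuilt as a
// linalg.generic over the packed operands, and packed inits are unpacked
// back on the way out.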
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {

        "incorrect number of pack sizes");

      linalgOp.getIteratorTypesArray();
  LDBG() << "Start packing: " << linalgOp;
  LDBG() << "maps: " << llvm::interleaved(indexingMaps);
  LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);

  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {

    if (maybeConstant.has_value() && maybeConstant.value() == 0)

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =

    if (failed(maybePackedDimForEachOperand))

    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
    LDBG() << "maps: " << llvm::interleaved(indexingMaps);
    LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
    LDBG() << "packedDimForEachOperand: "
           << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);

      llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));

  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();

          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);

          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LDBG() << "operand: " << operand;
      LDBG() << "innerPos: " << llvm::interleaved(innerPos);
      LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);

      Value dest = linalg::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,

      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =

      if (areConstantTiles && operandType.hasStaticShape() &&
          !linalg::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},

        packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
                                                 innerPos, innerPackSizes));

        Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
        packOps.push_back(linalg::PackOp::create(
            rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));

      inputsAndInits.push_back(packOps.back());

  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp =
      linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
                                inputs, inits, indexingMaps, iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    linalg::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<linalg::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);

    unPackOps.push_back(linalg::UnPackOp::create(
        rewriter, packedLinalgOp->getLoc(), result,
        maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(),
        maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());

      cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
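// Helper for packTranspose: swaps one operand of `linalgOp` for its
// pre-transposed `transposedValue` and composes the matching indexing map
// with the permutation so the access pattern stays equivalent; the op is
// rebuilt as a linalg.generic.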
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

      cast<RankedTensorType>(opOperand.get().getType()), permutation);

  assert(tensorType == transposedValue.getType() &&
         "unexpected tensor type mismatch");

      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));

      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;

  auto transposedGenericOp = linalg::GenericOp::create(

      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      operandsRef.take_front(linalgOp.getNumDpsInputs()),
      operandsRef.drop_front(linalgOp.getNumDpsInputs()),

      linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
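// packTranspose transposes "a single PackOp -> LinalgOp -> UnPackOp chain"
// (header docs) by the given outerPerm/innerPerm: it clones the pack (and
// unpack, if present) with transposed layouts and fixes up the LinalgOp in
// between via the helper above.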
FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
                      linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {

  linalg::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
        linalgOp, "not a single use by the LinalgOp target");

      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
        "not produced by the LinalgOp target");

  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();

  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));

  if (innerPerm.empty()) {
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));

    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;

  LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  linalg::UnPackOp transposedUnPackOp;

      transposedLinalgOp->getOpOperand(packUseOperandNumber);
  OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);

  transposedUnPackOp = maybeUnPackOp.createTransposedClone(
      rewriter, loc, transposedResult, innerPerm, outerPerm);

  rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());

  rewriter.replaceOp(packOp, transposedPackOp->getResults());
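// packMatmulGreedily packs a LinalgOp "by greedily inferring matmul
// dimensions (m, n, k)" (header docs). mnkPackedSizes/mnkOrder route the
// three packing sizes to the innermost loops; when
// mnkPaddedSizesNextMultipleOf is non-empty, each packed size is first
// rounded up with the affine expression ceilDiv(loopSize, multiple) *
// multiple, as built below.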
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");

  int64_t numLoops = linalgOp.getNumLoops();

    LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
           << " in: " << linalgOp;

        linalgOp, "need 3+ loops to find a matmul to pack");

  int64_t numPackedDims = mnkPackedSizes.size();

  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];

  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];

  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];

  FailureOr<ContractionDimensions> maybeDimensions =

  if (failed(maybeDimensions)) {
    LDBG() << "couldn't infer matmul iterators in: " << linalgOp;

        "couldn't infer matmul iterators");

  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
         << nPos << ", k@" << kPos << "): " << linalgOp;

  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());

    FailureOr<GenericOp> generalizeResult =

    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;

  LDBG() << "perm: " << llvm::interleaved(permutation);

  FailureOr<GenericOp> interchangeResult =

  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LDBG() << "Generalized Op to pack: " << genericOp;

      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  LDBG() << "paddedSizesNextMultipleOf: "
         << llvm::interleaved(paddedSizesNextMultipleOf);
  LDBG() << "loopRanges: "
         << llvm::interleaved(
                llvm::map_range(loopRanges, [](Range r) { return r.size; }));

  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);

        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));

  LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);

  return pack(rewriter, genericOp, adjustedPackedSizes);
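// LinalgTilingOptions::setTileSizes stores a lazy tileSizeComputationFunction
// that materializes the constant tile sizes at the start of the enclosing
// function body when tiling actually runs.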
  assert(!tileSizeComputationFunction && "tile sizes already set");

        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
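// DecomposePadOpPattern helper, documented as "filling dest using FillOp
// [with the] constant padding value if possible": a constant pad value
// (hoisted out of the pad body if needed) becomes a linalg.fill; otherwise
// the pad region is cloned into a tensor.generate.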
  auto padValue = padOp.getConstantPaddingValue();

  if (padValue.getParentBlock() == &padOp.getRegion().front())
    rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
  return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();

  auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
                                               padOp.getResultType(), dynSizes);

  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);

    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return arith::ConstantIndexOp::create(
               rewriter, padOp.getLoc(),
               cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())

  auto resultType = padOp.getResultType();

  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {

          padOp.getSource(), dim));

          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));

          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);

      staticSizes.push_back(resultType.getDimSize(dim));

      tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
                              resultType.getElementType(), dynSizes);

  auto sourceType = padOp.getSourceType();

      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
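// ExtractSliceOfPadTensorSwapPattern: for a unit-stride tensor.extract_slice
// of a tensor.pad, tensor::bubbleUpPadSlice takes the slice first and pads
// afterwards. zeroSliceGuard (overridable through controlFn) protects
// against zero-sized slices; a rank-reducing slice is re-ranked before the
// replacement.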
  if (!sliceOp.hasUnitStride())

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();

  bool zeroSliceGuard = true;

  if (std::optional<bool> control = controlFn(sliceOp))
    zeroSliceGuard = *control;

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();

  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);

      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
  rewriter.replaceOp(sliceOp, rankReduced);
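// Helper for the pack decomposition below: returns the pack source as-is
// when the op has no padding value; otherwise (all outer dims being 1) each
// tiled dimension is padded up to its tile size with
// tensor::createPadHighOp, collecting dynamic tile sizes on the side.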
                                           linalg::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

      packOp.getDimAndTileMapping();

  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {

    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);

    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());

    paddedShape.push_back(ShapedType::kDynamic);

    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));

                                 /*nofold=*/false, loc, builder,
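// Permutation helpers behind getPackInverseDestPerm /
// getUnPackInverseSrcPerm (header docs): dimensions still carrying
// kNonTiledMarker are non-tiled; the rest are split into outer and inner
// positions, the inner/outer permutations are applied, and the inner dims
// are appended last.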
  constexpr int64_t kNonTiledMarker = -1;

      vec, [&](int64_t v) { return v != kNonTiledMarker; });

  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {

      innerDims.push_back(dim++);

      outerDims.push_back(dim++);

  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  rankReducedOuterDimsPerm =

  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  perm.append(innerDims);
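// DecomposeOuterUnitDimsPackOpPattern: when every outer dimension of the
// packed result is 1, the pack reduces to (optionally) padding the source,
// a linalg.transpose into an empty tensor of the tile shape, and a
// unit-offset tensor.insert_slice into the destination.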
  if (llvm::any_of(packOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
        packOp, "not all outer dimensions of the result are 1s");

  int64_t prev = 0;

    if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
                       packOp.getType().getShape()[dim] != 1))

        packOp, "At least one non-unit and un-tiled outer dim is permuted, "
                "this is not supported ATM!");

  int64_t srcRank = packOp.getSourceRank();
  int64_t destRank = packOp.getDestRank();

  for (int64_t i = 0; i < srcRank; i++) {

    srcPermForTranspose.push_back(i);

  ShapedType inputTy = cast<ShapedType>(input.getType());

  for (int64_t i = 0; i < srcRank; i++) {

    if (inputTy.isStaticDim(i))
      shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));

      shapeForEmptyOp.emplace_back(
          tensor::DimOp::create(rewriter, loc, input, i).getResult());

  shapeForEmptyOp.append(packOp.getMixedTiles());

  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),

                    if (auto val = llvm::dyn_cast<Value>(ofr))
                      return getAsOpFoldResult(val);

  LDBG() << "Pack permutation: " << packOp;
  LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);

  Value empty = tensor::EmptyOp::create(
      rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());

  auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                  srcPermForTranspose);

  for (auto size : packOp.getAllOuterDims()) {

  for (auto tileSize : packOp.getMixedTiles()) {
    auto [_, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);

  auto insert = tensor::InsertSliceOp::create(
      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
      writeOffsets, writeSizes, writeStrides);

  rewriter.replaceOp(packOp, insert.getResult());
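// DecomposeOuterUnitDimsUnPackOpPattern is the mirror image: with all tiled
// outer dims equal to 1, extract the inner tile with tensor.extract_slice,
// linalg.transpose it back into source order, and write it to the
// destination with a unit-stride tensor.insert_slice.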
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();

  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
        "require the tiled outer dimensions of the result are all 1s");

  Value source = unpackOp.getSource();

      unpackOp.getDimAndTileMapping();

  for (auto i : llvm::seq<unsigned>(0, destRank)) {

    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);

    if (ShapedType::isDynamic(srcShape[i])) {
      Value dynamicDim =
          tensor::DimOp::create(rewriter, loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);

      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));

    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);

  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  auto tileShape = srcShape.drop_front(destRank);

  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();

  Value innerTile = tensor::ExtractSliceOp::create(
      rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
      extractSliceSizes, extractSliceStrides);

      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());

  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);

  Value empty =
      tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
  auto transposedOp =
      linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);

  int numLoops = shapeForEmptyOp.size();

  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(

  auto partialTile =
      tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
                                     tileOffsets, tileSizes, tileStrides);

  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);

      writeSizes.push_back(oneIdxAttr);

  auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
                                              unpackOp.getDest(), writeOffsets,
                                              writeSizes, writeStrides);
  rewriter.replaceOp(unpackOp, insert.getResult());
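// DownscaleSizeOneWindowed2DConvolution rewrites "2-D convolution ops with
// size-1 window dimensions into 1-D convolution ops" (header docs). The
// TypeSwitch below picks, per layout, which kernel/output indices hold the
// (kh, kw, oh, ow) extents; a spatial dimension is dropped only when both
// its kernel and output extents are 1.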
template <typename Conv2DOp, typename Conv1DOp>

  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  auto [khIndex, kwIndex, ohIndex, owIndex] =

          .Case([&](linalg::Conv2DNhwcHwcfOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::Conv2DNchwFchwOp op) {
            return std::make_tuple(2, 3, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcSumOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwSumOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .Case([&](linalg::PoolingNhwcMaxOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
            return std::make_tuple(0, 1, 1, 2);
          })
          .Case([&](linalg::PoolingNchwMaxOp op) {
            return std::make_tuple(0, 1, 2, 3);
          })
          .DefaultUnreachable("unexpected conv2d/pool2d operation.");

  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

      rewriter, loc, input, newInputType);

      rewriter, loc, kernel, newKernelType);

      rewriter, loc, output, newOutputType);

  auto strides =
      llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));

  auto conv1DOp = Conv1DOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

      rewriter, loc, conv1DOp.getResult(0), output);
                       PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;

                       PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
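// The depthwise variant: a DepthwiseConv2DNhwcHwcOp with a size-1 (w, kw)
// or (h, kh) pair becomes a DepthwiseConv1DNwcWcOp (header docs), with the
// NHWC/HWC index positions hard-coded instead of a TypeSwitch.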
FailureOr<DepthwiseConv1DNwcWcOp>

  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

      rewriter, loc, input, newInputType);

      rewriter, loc, kernel, newKernelType);

      rewriter, loc, output, newOutputType);

  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
  strides.erase(strides.begin() + (removeH ? 0 : 1));

  auto dilations =
      llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));

  auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

      rewriter, loc, conv1DOp.getResult(0), output);
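// The same downscaling for plain 2-D convs whose kernel and output are
// indexed directly by (h, w): drop the unit dimension and emit the
// corresponding 1-D convolution.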
  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getInputs().front();
  Value kernel = convOp.getInputs().back();
  Value output = convOp.getOutputs().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

      rewriter, loc, input, newInputType);

      rewriter, loc, kernel, newKernelType);

      rewriter, loc, output, newOutputType);

  auto conv1DOp =
      Conv1DOp::create(rewriter, loc, newOutputType,

      rewriter, loc, conv1DOp.getResult(0), output);
                       PoolingNwcMaxUnsignedOp>,

                       PoolingNwcMinUnsignedOp>,
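// populateDecomposeConvolutionPatterns registers the downscaling patterns
// above ("Linalg decompose convolutions patterns", per the header docs) for
// the conv and pooling variants with the given pattern benefit.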