#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "linalg-transforms"
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
  for (auto loopOp : loops)

  if (!e.isFunctionOfDim(dim))

  return llvm::interleaved(ri, ", ", "|", "");
static FailureOr<SmallVector<std::optional<int64_t>>>
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);
      indexingMaps.size(), std::nullopt);
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
    AffineMap map = indexingMaps[operandIdx];
    assert(map.getNumDims() == newDim && "num dims invariant violation");
           "num results invariant violation");
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
    newMaps.push_back(map);
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  indexingMaps = newMaps;
  return packedDimPerIndexingMap;
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;

struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
  SmallVector<PackedOperandsDim> spec;
                                  linalg::PackOp packOp,
                                  bool lowerPadLikeWithInsertSlice) {
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");

  PackingMetadata packingMetadata;
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
        rewriter, loc, map, {outerSize, origSize, innerSize});

  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
    paddingValue = arith::ConstantOp::create(
      tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
                            highs, paddingValue, false);
  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "outerPositions: "
         << llvm::interleaved(packingMetadata.outerPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsed;
  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    auto insertSliceOp = tensor::InsertSliceOp::create(
        rewriter, loc, padOp, packOp.getDest(),
    LDBG() << "insert_slice op: " << insertSliceOp;
    rewriter.replaceOp(packOp, insertSliceOp->getResults());

  auto expandShapeResultType =
  auto reshapeOp = tensor::ExpandShapeOp::create(
      rewriter, loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);
  auto transposeOp = linalg::TransposeOp::create(
      rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LDBG() << "reshape op: " << reshapeOp;
  LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
  LDBG() << "transpose op: " << transposeOp;

  rewriter.replaceOp(packOp, transposeOp->getResults());
FailureOr<LowerUnPackOpResult>
                     bool lowerUnpadLikeWithExtractSlice) {
  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    auto extractSliceOp = tensor::ExtractSliceOp::create(
        rewriter, loc, destTensorType, unPackOp.getSource(),
    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
                             nullptr, extractSliceOp};

  PackingMetadata packingMetadata;
  RankedTensorType stripMinedTensorType =
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
                                         stripMinedTensorType.getElementType());
      linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
                                  packedToStripMinedShapePerm);

  LDBG() << "insertPositions: "
         << llvm::interleaved(packingMetadata.insertPositions);
  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
  LDBG() << "packedToStripMinedShapePerm: "
         << llvm::interleaved(packedToStripMinedShapePerm);
  LDBG() << "reassociations: "
         << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
  LDBG() << "collapsed type: " << collapsedType;

  auto reshapeOp = tensor::CollapseShapeOp::create(
      rewriter, loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = tensor::ExtractSliceOp::create(
      rewriter, loc, destTensorType, reshapeOp->getResult(0),
  auto copyOp = linalg::CopyOp::create(
      rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());

  rewriter.replaceOp(unPackOp, copyOp->getResults());
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
    res.push_back(i.packedDimForEachOperand[operandPos].value());

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
    res.push_back(i.packedSize);
                              linalg::LinalgOp linalgOp,
  if (packedSizes.size() != linalgOp.getNumLoops()) {
                                       "incorrect number of pack sizes");
      linalgOp.getIteratorTypesArray();
  LDBG() << "Start packing: " << linalgOp;
  LDBG() << "maps: " << llvm::interleaved(indexingMaps);
  LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);

  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
    if (failed(maybePackedDimForEachOperand))
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;

    LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
    LDBG() << "maps: " << llvm::interleaved(indexingMaps);
    LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
    LDBG() << "packedDimForEachOperand: "
           << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

      llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LDBG() << "operand: " << operand;
      LDBG() << "innerPos: " << llvm::interleaved(innerPos);
      LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
      Value dest = linalg::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
      if (areConstantTiles && operandType.hasStaticShape() &&
          !linalg::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
        packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
                                                 innerPos, innerPackSizes));
        Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
        packOps.push_back(linalg::PackOp::create(
            rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
      inputsAndInits.push_back(packOps.back());

      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp =
      linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
                                inputs, inits, indexingMaps, iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

    linalg::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<linalg::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
    unPackOps.push_back(linalg::UnPackOp::create(
        rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());

      cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;

  auto transposedGenericOp = linalg::GenericOp::create(
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      operandsRef.take_front(linalgOp.getNumDpsInputs()),
      operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
FailureOr<PackTransposeResult>
                       linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
  linalg::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
        linalgOp, "not a single use by the LinalgOp target");
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
                                       "not produced by the LinalgOp target");

  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  if (innerPerm.empty()) {
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;

      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  linalg::UnPackOp transposedUnPackOp;
      transposedLinalgOp->getOpOperand(packUseOperandNumber);
  OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);
    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());

  rewriter.replaceOp(packOp, transposedPackOp->getResults());
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");

  int64_t numLoops = linalgOp.getNumLoops();
    LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
           << " in: " << linalgOp;
        linalgOp, "need 3+ loops to find a matmul to pack");

  int64_t numPackedDims = mnkPackedSizes.size();
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];

  FailureOr<ContractionDimensions> maybeDimensions =
  if (failed(maybeDimensions)) {
    LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
                                       "couldn't infer matmul iterators");

  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
         << nPos << ", k@" << kPos << "): " << linalgOp;

  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
    FailureOr<GenericOp> generalizeResult =
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;

  LDBG() << "perm: " << llvm::interleaved(permutation);
  FailureOr<GenericOp> interchangeResult =
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LDBG() << "Generalized Op to pack: " << genericOp;

      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());
  LDBG() << "paddedSizesNextMultipleOf: "
         << llvm::interleaved(paddedSizesNextMultipleOf);
  LDBG() << "loopRanges: "
         << llvm::interleaved(
                llvm::map_range(loopRanges, [](Range r) { return r.size; }));

  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);

  return pack(rewriter, genericOp, adjustedPackedSizes);
  b.setInsertionPointToStart(
      &op->getParentOfType<func::FuncOp>().getBody().front());
  return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue.getParentBlock() == &padOp.getRegion().front())
    return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();

  auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
                                               padOp.getResultType(), dynSizes);
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);

    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
        rewriter, padOp.getLoc(),
        cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())

  auto resultType = padOp.getResultType();
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
          padOp.getSource(), dim));
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    staticSizes.push_back(resultType.getDimSize(dim));

      tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
                              resultType.getElementType(), dynSizes);
  auto sourceType = padOp.getSourceType();
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
  if (!sliceOp.hasUnitStride())

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();

  bool zeroSliceGuard = true;
  if (std::optional<bool> control = controlFn(sliceOp))
    zeroSliceGuard = *control;

  FailureOr<TilingResult> tilingResult =
      sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();
  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
  rewriter.replaceOp(sliceOp, rankReduced);
                                          linalg::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();
      packOp.getDimAndTileMapping();
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
    paddedShape.push_back(ShapedType::kDynamic);
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
      RankedTensorType::get(paddedShape, inputType.getElementType());
                          false, loc, builder,
static SmallVector<int64_t>
  constexpr int64_t kNonTiledMarker = -1;
  for (auto [index, value] : llvm::enumerate(perm))
      vec, [&](int64_t v) { return v != kNonTiledMarker; });

static SmallVector<int64_t>
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);

  rankReducedOuterDimsPerm =
  if (!rankReducedOuterDimsPerm.empty())
  perm.append(innerDims);
  if (llvm::any_of(packOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
        packOp, "not all outer dimensions of the result are 1s");

  auto outerDimsPerm = packOp.getOuterDimsPerm();
  if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
        static int prev = 0;
        if (llvm::is_contained(innerDimsPos, dim))
        if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
                           packOp.getType().getShape()[dim] != 1))
        packOp, "At least one non-unit and un-tiled outer dim is permuted, "
                "this is not supported ATM!");

  int64_t srcRank = packOp.getSourceRank();
  for (int64_t i = 0; i < srcRank; i++) {
    if (llvm::is_contained(innerDimsPos, i))
    srcPermForTranspose.push_back(i);
  srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

  ShapedType inputTy = cast<ShapedType>(input.getType());
  for (int64_t i = 0; i < srcRank; i++) {
    if (llvm::is_contained(innerDimsPos, i)) {
    if (inputTy.isStaticDim(i))
      shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
      shapeForEmptyOp.emplace_back(
          tensor::DimOp::create(rewriter, loc, input, i).getResult());
  shapeForEmptyOp.append(packOp.getMixedTiles());

  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
                    if (auto val = llvm::dyn_cast<Value>(ofr))
                      return getAsOpFoldResult(val);

  LDBG() << "Pack permutation: " << packOp;
  LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);

  Value empty = tensor::EmptyOp::create(
      rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
  auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
                                                  srcPermForTranspose);

  for (auto size : packOp.getAllOuterDims()) {
  for (auto tileSize : packOp.getMixedTiles()) {
    auto [_, tileSizeOfr] =
    writeSizes.push_back(tileSizeOfr);

  auto insert = tensor::InsertSliceOp::create(
      rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), writeSizes);
  rewriter.replaceOp(packOp, insert.getResult());
  int64_t destRank = unpackOp.getDestRank();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
        "require the tiled outer dimensions of the result are all 1s");

  Value source = unpackOp.getSource();
      unpackOp.getDimAndTileMapping();
  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
    if (ShapedType::isDynamic(srcShape[i])) {
          tensor::DimOp::create(rewriter, loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);

  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  auto tileShape = srcShape.drop_front(destRank);
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = tensor::ExtractSliceOp::create(
      rewriter, loc, readType, unpackOp.getSource(), extractSliceSizes);

      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
      tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
      linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);

  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      tileSizes.push_back(
      tensor::ExtractSliceOp::create(rewriter, loc, RankedTensorType(),
                                     transposedOp.getResult()[0], tileSizes);

  for (int i = 0, idx = 0; i < destRank; ++i) {
    if (dimAndTileMapping.count(i) || destShape[i] != 1)
      writeSizes.push_back(tileSizes[idx++]);
      writeSizes.push_back(oneIdxAttr);
  auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
                                              unpackOp.getDest(), writeSizes);
  rewriter.replaceOp(unpackOp, insert.getResult());
template <typename Conv2DOp, typename Conv1DOp>
  std::optional<DilationsAndStrides> convParams =
  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getDpsInputs().front();
  Value kernel = convOp.getDpsInputs().back();
  Value output = convOp.getDpsInits().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  int64_t khIndex, kwIndex, ohIndex, owIndex;
  if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNhwcHwcfOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcSumOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMaxUnsignedOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinOp> ||
                std::is_same_v<Conv2DOp, linalg::PoolingNhwcMinUnsignedOp>) {
  } else if constexpr (std::is_same_v<Conv2DOp, linalg::Conv2DNchwFchwOp>) {
  } else if constexpr (std::is_same_v<Conv2DOp, linalg::PoolingNchwSumOp> ||
                       std::is_same_v<Conv2DOp, linalg::PoolingNchwMaxOp>) {

  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));

      rewriter, loc, input, newInputType);
      rewriter, loc, kernel, newKernelType);
      rewriter, loc, output, newOutputType);

  strides.erase(strides.begin() + (removeH ? 0 : 1));
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));

  auto conv1DOp = Conv1DOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

      rewriter, loc, conv1DOp.getResult(0), output);

                      PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
                      PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
FailureOr<DepthwiseConv1DNwcWcOp>
  std::optional<DilationsAndStrides> convParams =
  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getDpsInputs().front();
  Value kernel = convOp.getDpsInputs().back();
  Value output = convOp.getDpsInits().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[1], owSize = outputShape[2];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);

      rewriter, loc, input, newInputType);
      rewriter, loc, kernel, newKernelType);
      rewriter, loc, output, newOutputType);

  strides.erase(strides.begin() + (removeH ? 0 : 1));
  dilations.erase(dilations.begin() + (removeH ? 0 : 1));

  auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
      rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
      ValueRange{newOutput}, stridesAttr, dilationsAttr);

      rewriter, loc, conv1DOp.getResult(0), output);
  std::optional<DilationsAndStrides> convParams =
  if (convOp.hasPureBufferSemantics())

  Value input = convOp.getDpsInputs().front();
  Value kernel = convOp.getDpsInputs().back();
  Value output = convOp.getDpsInits().front();

  auto inputType = dyn_cast<RankedTensorType>(input.getType());
  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
  auto outputType = dyn_cast<RankedTensorType>(output.getType());

  auto kernelShape = kernelType.getShape();
  auto outputShape = outputType.getShape();

  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
  int64_t ohSize = outputShape[0], owSize = outputShape[1];
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW)

  RankedTensorType newInputType =
      RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
  RankedTensorType newKernelType =
      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
  RankedTensorType newOutputType =
      RTTBuilder(outputType).dropDim(removeH ? 0 : 1);

      rewriter, loc, input, newInputType);
      rewriter, loc, kernel, newKernelType);
      rewriter, loc, output, newOutputType);

      Conv1DOp::create(rewriter, loc, newOutputType,
      rewriter, loc, conv1DOp.getResult(0), output);

                                         PoolingNwcMaxUnsignedOp>,
                                         PoolingNwcMinUnsignedOp>,
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
MLIRContext * getContext() const
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, linalg::PackOp packOp, linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
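A minimal usage sketch for packTranspose, assuming an existing pack -> generic -> unpack chain (packOp, genericOp, unPackOp) produced by an earlier transformation; the rewriter, the ops, and the permutation values are illustrative assumptions, only the call signature and the PackTransposeResult fields come from this API.
  SmallVector<int64_t> outerPerm = {1, 0}; // illustrative values
  SmallVector<int64_t> innerPerm = {0, 1}; // illustrative values
  FailureOr<PackTransposeResult> res = linalg::packTranspose(
      rewriter, packOp, genericOp, unPackOp, outerPerm, innerPerm);
  if (failed(res))
    return failure();
  // On success the transposed chain is returned in res->transposedPackOp,
  // res->transposedLinalgOp and res->transposedUnPackOp.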
std::optional< DilationsAndStrides > matchConvolutionOpOfType(LinalgOp op)
Given a linalg op this function returns DilationsAndStrides if it is a convolution op of type ConvOpT...
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice=true)
Rewrite unpack as empty + transpose + reshape + extract_slice.
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and apply affine_min/max bounds simplification on the fly where relevant.
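A hedged sketch of calling peelLoops on loops produced by tiling; the wrapper function and its tiledLoops argument are illustrative, not part of this file.
  static void peelAllTiledLoops(RewriterBase &rewriter,
                                ArrayRef<scf::ForOp> tiledLoops) {
    // Each loop whose step does not evenly divide its trip count is split into
    // a main loop plus a partial last iteration (see
    // peelForLoopAndSimplifyBounds), with affine.min/max bounds simplified.
    linalg::peelLoops(rewriter, tiledLoops);
  }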
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp linalgOp)
Create a GenericOp from the given named operation linalgOp and replace the given linalgOp.
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns)
Populates patterns to decompose linalg.pack and linalg.unpack Ops into e.g.
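A sketch of how such populate* entry points are typically driven from a pass; the pass skeleton and the greedy driver call are assumptions about the caller and the MLIR version in use, not APIs defined in this file.
  void runOnOperation() /* override, inside an assumed pass class */ {
    RewritePatternSet patterns(&getContext());
    linalg::populateDecomposePackUnpackPatterns(patterns);
    linalg::populateDecomposePadPatterns(patterns);
    // Apply the collected patterns until fixpoint with the greedy driver.
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }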
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
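A hedged sketch of a packMatmulGreedily call; the m/n/k tile sizes and the surrounding rewriter/linalgOp values are illustrative assumptions.
  SmallVector<OpFoldResult> mnkPackedSizes = {rewriter.getIndexAttr(8),
                                              rewriter.getIndexAttr(16),
                                              rewriter.getIndexAttr(32)};
  FailureOr<PackResult> packed = linalg::packMatmulGreedily(
      rewriter, linalgOp, mnkPackedSizes,
      /*mnkPaddedSizesNextMultipleOf=*/{}, /*mnkOrder=*/{0, 1, 2});
  if (failed(packed))
    return failure();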
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
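A hedged sketch of calling pack directly, with one packed size per loop of the op; a constant-zero size skips packing of that loop (see the constant-zero check in the implementation above). The concrete values are illustrative.
  SmallVector<OpFoldResult> packedSizes(linalgOp.getNumLoops(),
                                        rewriter.getIndexAttr(0));
  packedSizes[0] = rewriter.getIndexAttr(8); // only pack the first loop, by 8
  FailureOr<PackResult> res = linalg::pack(rewriter, linalgOp, packedSizes);
  if (failed(res))
    return failure();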
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
void populateDecomposePadPatterns(RewritePatternSet &patterns)
Populates patterns to decompose tensor.pad into e.g.
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice=true)
Rewrite pack as pad + reshape + transpose.
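A minimal sketch of wrapping lowerPack in a rewrite pattern; the pattern class and its name are assumptions for illustration, only lowerPack and LowerPackResult come from this API.
  struct LowerPackAsPadReshapeTranspose : OpRewritePattern<linalg::PackOp> {
    using OpRewritePattern::OpRewritePattern;
    LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                  PatternRewriter &rewriter) const override {
      // lowerPack replaces the pack op on success; just propagate the status.
      FailureOr<LowerPackResult> res = linalg::lowerPack(rewriter, packOp);
      return success(succeeded(res));
    }
  };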
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
ArrayRef< int64_t > ReassociationIndicesRef
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
std::pair< int64_t, OpFoldResult > getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b)
Given OpFoldResult representing dim size value (*), generates a pair of sizes:
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
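A small worked example with illustrative values; entry i of the result is inVec[permutation[i]].
  SmallVector<int64_t> shape = {2, 3, 4};
  applyPermutationToVector(shape, /*permutation=*/{2, 0, 1});
  // shape is now {4, 2, 3}.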
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
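Worked example with the same illustrative permutation as above; applying a permutation followed by its inverse restores the original order.
  SmallVector<int64_t> inverse = invertPermutationVector({2, 0, 1});
  // inverse == {1, 2, 0}; applyPermutationToVector on {4, 2, 3} with this
  // inverse maps the vector back to {2, 3, 4}.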
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Rewrites a linalg::PackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override
Rewrites a linalg::UnPackOp into a sequence of:
LogicalResult matchAndRewrite(linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const override
Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and InsertSliceOp.
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
FailureOr< Conv1DOp > returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
FailureOr< Conv1DOp > returningMatchAndRewrite(LinalgOp convOp, PatternRewriter &rewriter) const
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
TileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
Struct to hold the result of a pack call.
Struct to hold the result of a packTranspose call.