#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-tiling-interface-impl"
/// Return the SSA values that represent the data point accessed using a given
/// `indexingMap` for a given point in the iteration space represented by
/// `ivs`.
static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
                                              AffineMap indexingMap,
                                              ValueRange ivs) {
  SmallVector<Value> indices;
  indices.reserve(indexingMap.getNumResults());
  for (auto result : indexingMap.getResults()) {
    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
                                 indexingMap.getNumSymbols(), result);
    Value v = affine::AffineApplyOp::create(b, loc, m, ivs);
    indices.push_back(v);
  }
  return indices;
}
/// Method to inline the payload of a linalgOp given the iteration space point
/// and values for the arguments of the payload.
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
                                   ValueRange ivs, ValueRange argValues) {
  Block *body = linalgOp.getBlock();
  IRMapping map;
  map.map(body->getArguments(), argValues);
  for (auto &op : body->without_terminator()) {
    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
      map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
      continue;
    }
    b.clone(op, map);
  }
  Operation *terminator = body->getTerminator();
  for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
    Value toStore = map.lookupOrDefault(operand.value());
    OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
    auto indices = getIndicesForAccess(
        b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
    memref::StoreOp::create(b, loc, toStore,
                            linalgOp.getDpsInitOperand(operand.index())->get(),
                            indices);
  }
  return success();
}
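/// External model implementation of TilingInterface for LinalgOps.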
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
    return concreteOp.getIteratorTypesArray();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<OpFoldResult> allShapesSizes =
        linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap map = linalgOp.getShapesToLoopsMap();

    return llvm::to_vector(
        llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
          OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
              b, loc, loopExpr, allShapesSizes);
          return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
        }));
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, valuesToTile, offsets, sizes, {},
        /*omitPartialTileCheck=*/true);
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledOperands,
            [](Value v) -> bool {
              return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
                  v.getDefiningOp());
            }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });
    SmallVector<Type> resultTensorTypes =
        getTensorOutputTypes(linalgOp, tiledOperands);
    Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
    return TilingResult{
        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
  }
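  /// Utility to fetch the offsets and sizes when applied as per the indexing
  /// map of the linalg op. This helps in fusing the linalg op as a consumer of
  /// a given slice op.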
  static LogicalResult
  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
                         ArrayRef<AffineMap> indexingMaps,
                         ArrayRef<SmallVector<OpFoldResult>> allOffsets,
                         ArrayRef<SmallVector<OpFoldResult>> allSizes,
                         SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
                         SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
    llvm::SmallDenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
    for (auto [indexingMap, offsets, sizes] :
         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
      for (auto [resultExpr, offset, size] :
           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
        if (!dimExpr)
          continue;
        unsigned position = dimExpr.getPosition();
        auto it = mappedOffsets.find(position);
        if (it != mappedOffsets.end()) {
          OpFoldResult seenOffset = it->second;
          OpFoldResult seenSize = mappedSizes.lookup(position);
          if (seenOffset != offset || seenSize != size) {
            LLVM_DEBUG({
              llvm::dbgs() << "inconsistent iteration space mapping from "
                              "offsets/sizes of operands/results";
            });
            return failure();
          }
        } else {
          mappedOffsets[position] = offset;
          mappedSizes[position] = size;
        }
      }
    }

    // Aggregate the mapped offsets/sizes; loop dimensions not mapped by any
    // operand default to the full iteration domain.
    SmallVector<Range> iterationDomain =
        cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
    mappedOffsetsVec.resize(iterationDomain.size());
    mappedSizesVec.resize(iterationDomain.size());
    for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
      auto it = mappedOffsets.find(index);
      if (it != mappedOffsets.end()) {
        mappedOffsetsVec[index] = it->second;
        mappedSizesVec[index] = mappedSizes.lookup(index);
        continue;
      }
      mappedOffsetsVec[index] = domain.offset;
      mappedSizesVec[index] = domain.size;
    }
    return success();
  }
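  /// Method to return the position of the iteration domain tile computed from
  /// the given operand tiles.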
  LogicalResult getIterationDomainTileFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
        iterationSpaceSizes;
    SmallVector<AffineMap> indexingMaps =
        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
          return linalgOp.getMatchingIndexingMap(&opOperand);
        });
    if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
                                      allSizes, iterDomainOffsets,
                                      iterDomainSizes))) {
      return failure();
    }
    return success();
  }
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    Location loc = op->getLoc();
    LinalgOp linalgOp = cast<LinalgOp>(op);

    AffineExpr d0;
    bindDims(b.getContext(), d0);
    SmallVector<OpFoldResult> subShapeSizes =
        llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
          return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
        }));

    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    SliceParameters sliceParams = computeSliceParameters(
        b, loc, outOperand->get(), sizes,
        linalgOp.getMatchingIndexingMap(outOperand), offsets,
        /*ubs=*/{}, subShapeSizes, /*omitPartialTileCheck=*/true);
    resultOffsets = sliceParams.offsets;
    resultSizes = sliceParams.sizes;
    return success();
  }
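  /// Method to return the iteration domain tile that computes the given tile
  /// of the result.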
  LogicalResult getIterationDomainTileFromResultTile(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
    auto linalgOp = cast<LinalgOp>(op);

    // The result tile can only be mapped back to the iteration domain when the
    // result is accessed through a projected permutation.
    AffineMap indexingMap =
        linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
    if (!indexingMap.isProjectedPermutation()) {
      return op->emitOpError(
          "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
    }

    SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
    SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
    auto status =
        getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
                               {allSizes}, iterDomainOffsets, iterDomainSizes);
    (void)status;
    assert(succeeded(status) && "unexpected error in offset calculation");
    return success();
  }
  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromResultTile(
            op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
      return failure();
    }
    auto tilingInterfaceOp = cast<TilingInterface>(op);
    FailureOr<TilingResult> tilingResult =
        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

    if (failed(tilingResult))
      return failure();

    if (tilingResult->tiledOps.size() != 1)
      return op->emitOpError("failed to generate tiled implementation");

    return TilingResult{
        tilingResult->tiledOps,
        SmallVector<Value>{tilingResult->tiledOps[0]->getResult(resultNumber)},
        tilingResult->generatedSlices};
  }
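  /// Method to generate the tiled implementation of an operation from the
  /// tiles of the given operands.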
  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
    if (failed(getIterationDomainTileFromOperandTiles(
            op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
            mappedSizes))) {
      return failure();
    }
    return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
  }
  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
                                             Location loc,
                                             ValueRange ivs) const {
    auto linalgOp = cast<LinalgOp>(op);
    if (!linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have buffer semantics");

    SmallVector<Value> indexedValues;
    indexedValues.reserve(linalgOp->getNumOperands());
    Location linalgOpLoc = op->getLoc();
    // For each operand: skip it if unused by the payload, forward scalars
    // directly, and load from buffers at the indices implied by the matching
    // indexing map.
    for (OpOperand &operand : linalgOp->getOpOperands()) {
      if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
        indexedValues.push_back(nullptr);
        continue;
      }
      if (linalgOp.isScalar(&operand)) {
        indexedValues.push_back(operand.get());
        continue;
      }
      SmallVector<Value> indices = getIndicesForAccess(
          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
      Value load =
          memref::LoadOp::create(builder, linalgOpLoc, operand.get(), indices);
      indexedValues.push_back(load);
    }

    // Inline the op payload and store the results.
    return inlinePayload(builder, linalgOp, ivs, indexedValues);
  }
};
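/// Return the index of `value` within `reductionDims`, or std::nullopt if it
/// is not a reduction dimension.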
static std::optional<unsigned>
getPositionIn(const llvm::SetVector<unsigned> &reductionDims, unsigned value) {
  for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
    if (reductionDim == value) {
      return index;
    }
  }
  return std::nullopt;
}
static SmallVector<AffineMap>
getPartialResultAffineMaps(LinalgOp linalgOp,
                           const SetVector<unsigned> &reductionDims) {
  auto partialReductionMaps = llvm::map_to_vector(
      linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
        AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
        for (auto redPos : reductionDims) {
          map =
              map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
                               map.getNumResults());
        }
        return map;
      });
  return partialReductionMaps;
}
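/// Holds the static shape, offsets, sizes, and strides used to create the
/// slice of the destination operand of a partial reduction.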
struct InitSliceInfo {
  SmallVector<int64_t> resultShape;
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
};
static InitSliceInfo getInitSliceInfoForOuterReduction(
    MLIRContext *context, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
  SmallVector<OpFoldResult> initOffsets, initSizes;
  Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
    if (reductionDims.contains(dim)) {
      initOffsets.push_back(zero);
    } else {
      initOffsets.push_back(offsets[dim]);
    }
    initSizes.push_back(sizes[dim]);
  }
  SmallVector<int64_t> resultShape;
  std::tie(resultShape, std::ignore) = decomposeMixedValues(initSizes);
  SmallVector<OpFoldResult> initStrides(partialReductionMap.getNumResults(),
                                        one);
  return {resultShape, initOffsets, initSizes, initStrides};
}
static InitSliceInfo getInitSliceInfoForOuterParallel(
    MLIRContext *context, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
    ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
  SmallVector<OpFoldResult> initOffsets, initSizes;
  Attribute one = IntegerAttr::get(IndexType::get(context), 1);
  SmallVector<OpFoldResult> resultShape;
  for (AffineExpr dimExpr : partialReductionMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
    if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
      initOffsets.push_back(splitReductionIvs[dimPos.value()]);
      initSizes.push_back(one);
    } else {
      initOffsets.push_back(offsets[dim]);
      initSizes.push_back(sizes[dim]);
      resultShape.push_back(sizes[dim]);
    }
  }
  SmallVector<int64_t> staticShapes;
  std::tie(staticShapes, std::ignore) = decomposeMixedValues(resultShape);
  SmallVector<OpFoldResult> initStrides(partialReductionMap.getNumResults(),
                                        one);
  return {staticShapes, initOffsets, initSizes, initStrides};
}
static InitSliceInfo getInitSliceInfo(MLIRContext *context,
                                      ReductionTilingStrategy strategy,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes,
                                      const SetVector<unsigned> &reductionDims,
                                      ArrayRef<OpFoldResult> splitReductionIvs,
                                      AffineMap partialReductionMap) {
  if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
    return getInitSliceInfoForOuterReduction(context, offsets, sizes,
                                             reductionDims, splitReductionIvs,
                                             partialReductionMap);
  }
  assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
         "unexpected ReductionTilingStrategy");
  return getInitSliceInfoForOuterParallel(context, offsets, sizes,
                                          reductionDims, splitReductionIvs,
                                          partialReductionMap);
}
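/// External model implementation of PartialReductionOpInterface for LinalgOps.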
template <typename LinalgOpTy>
struct LinalgOpPartialReductionInterface
    : public PartialReductionOpInterface::ExternalModel<
          LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
  FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
      const SetVector<unsigned> &reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);
    OpBuilder::InsertionGuard guard(b);

    if (linalgOp.hasPureBufferSemantics())
      return op->emitOpError("expected operation to have tensor semantics");

    SmallVector<AffineMap> partialReductionMaps =
        getPartialResultAffineMaps(linalgOp, reductionDims);
    SmallVector<Value> inits;
    for (auto [initIdx, result, partialMap] :
         llvm::enumerate(linalgOp->getResults(), partialReductionMaps)) {
      SmallVector<Operation *, 4> combinerOps;
      if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                          combinerOps) ||
          combinerOps.size() != 1)
        return op->emitOpError("failed to analyze the reduction operation");

      Operation *reductionOp = combinerOps[0];
      std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
      if (!identity.has_value())
        return op->emitOpError(
            "failed to get an identity value for the reduction operation");
      // Append the partial result dimensions using the tile sizes.
      SmallVector<OpFoldResult> partialResultShape;
      for (AffineExpr dimExpr : partialMap.getResults()) {
        auto dim = cast<AffineDimExpr>(dimExpr);
        partialResultShape.push_back(sizes[dim.getPosition()]);
      }

      Type elType = getElementTypeOrSelf(result.getType());
      Value emptyTensor =
          tensor::EmptyOp::create(b, loc, partialResultShape, elType);
      Value constantOp = arith::ConstantOp::create(b, loc, *identity);
      auto identityTensor =
          linalg::FillOp::create(b, loc, constantOp, emptyTensor);
      inits.push_back(identityTensor.getResult(0));
    }
    return inits;
  }
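  /// Generate the partial reduction over the given tile of the iteration
  /// domain, writing into the partial-result init tensors.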
  FailureOr<TilingResult>
  tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
                         ReductionTilingStrategy tilingStrategy,
                         ValueRange init, ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         const SetVector<unsigned> &reductionDims,
                         ArrayRef<OpFoldResult> splitReductionIvs) const {
    OpBuilder::InsertionGuard guard(b);
    auto linalgOp = cast<LinalgOp>(op);

    SmallVector<AffineMap> partialReductionMaps =
        getPartialResultAffineMaps(linalgOp, reductionDims);
    // Step 1. Under the outer-reduction strategy the init maps gain the
    // reduction dims (they become parallel); otherwise the original init maps
    // are kept.
    SmallVector<AffineMap> newInitMaps;
    if (tilingStrategy ==
        ReductionTilingStrategy::PartialReductionOuterReduction) {
      newInitMaps = llvm::to_vector(partialReductionMaps);
    } else {
      newInitMaps = llvm::map_to_vector(
          linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
            return linalgOp.getMatchingIndexingMap(&opOperand);
          });
    }
    // Step 2a. Extract a slice of the input operands.
    SmallVector<Value> tiledInputs = makeTiledShapes(
        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {},
        /*omitPartialTileCheck=*/true);
    SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
        llvm::make_filter_range(
            tiledInputs,
            [](Value v) -> bool {
              return isa_and_nonnull<tensor::ExtractSliceOp>(v.getDefiningOp());
            }),
        [](Value v) -> Operation * { return v.getDefiningOp(); });
    // Step 2b. Extract a slice of the init operands.
    SmallVector<Value> tiledInits;
    for (auto [partialReductionMap, valueToTile] :
         llvm::zip_equal(partialReductionMaps, init)) {
      InitSliceInfo sliceInfo = getInitSliceInfo(
          b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
          splitReductionIvs, partialReductionMap);
      auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
      RankedTensorType tiledResultType = RankedTensorType::get(
          sliceInfo.resultShape, valueToTileType.getElementType(),
          valueToTileType.getEncoding());
      auto sliceOp = tensor::ExtractSliceOp::create(
          b, loc, tiledResultType, valueToTile, sliceInfo.offsets,
          sliceInfo.sizes, sliceInfo.strides);
      tiledInits.push_back(sliceOp.getResult());
      generatedSlices.push_back(sliceOp);
    }
    // Step 3. Apply the new init maps to the corresponding init operands.
    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
    for (auto [initOperand, newInitMap] :
         llvm::zip_equal(linalgOp.getDpsInitsMutable(), newInitMaps)) {
      int mapIdx = linalgOp.getIndexingMapIndex(&initOperand);
      newMaps[mapIdx] = newInitMap;
    }
    // Step 4. For the outer-reduction strategy, the tiled reduction dims
    // become parallel.
    SmallVector<utils::IteratorType> newIteratorTypes =
        linalgOp.getIteratorTypesArray();
    if (tilingStrategy ==
        ReductionTilingStrategy::PartialReductionOuterReduction) {
      for (int dim : reductionDims)
        newIteratorTypes[dim] = utils::IteratorType::parallel;
    }
    // Step 5. Create the partial reduction op: a generic op with the modified
    // maps/iterators for the outer-reduction strategy, a plain clone otherwise.
    Operation *partialReductionOp;
    auto resultTypes = ValueRange(tiledInits).getTypes();
    if (tilingStrategy ==
        ReductionTilingStrategy::PartialReductionOuterReduction) {
      auto genericOp = GenericOp::create(b, loc, resultTypes, tiledInputs,
                                         tiledInits, newMaps, newIteratorTypes);
      IRMapping mapping;
      op->getRegion(0).cloneInto(&genericOp.getRegion(),
                                 genericOp.getRegion().begin(), mapping);
      partialReductionOp = genericOp.getOperation();
    } else {
      SmallVector<Value> operands = std::move(tiledInputs);
      llvm::append_range(operands, tiledInits);
      partialReductionOp = mlir::clone(b, op, resultTypes, operands);
    }
    return TilingResult{
        {partialReductionOp},
        llvm::map_to_vector(partialReductionOp->getResults(),
                            [](OpResult r) -> Value { return r; }),
        generatedSlices};
  }
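  /// Merge the partial results back into the final results by reducing over
  /// the dimensions that were added for the partial reduction.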
  FailureOr<MergeResult>
  mergeReductions(Operation *op, OpBuilder &b, Location loc,
                  ValueRange partialReduce,
                  const SetVector<unsigned> &reductionDims) const {
    auto linalgOp = cast<LinalgOp>(op);
    SmallVector<AffineMap> partialReductionMaps =
        getPartialResultAffineMaps(linalgOp, reductionDims);

    SmallVector<Operation *> mergeOperations;
    SmallVector<Value> replacements;
    for (auto [idx, init, partialResult, partialMap] : llvm::enumerate(
             linalgOp.getDpsInits(), partialReduce, partialReductionMaps)) {
      unsigned initIdx = idx;
      // Gather which result positions of the partial map correspond to
      // reduction dims: these are the dims the merge reduces over.
      SmallVector<int64_t> partialReductionDims;
      for (auto [resultNum, dimExpr] :
           llvm::enumerate(partialMap.getResults())) {
        unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
        if (llvm::is_contained(reductionDims, dim)) {
          partialReductionDims.push_back(resultNum);
        }
      }
      auto reduction = linalg::ReduceOp::create(
          b, loc, partialResult, init, partialReductionDims,
          [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange inputs) {
            // Clone the combiner of the original reduction and yield its
            // result.
            SmallVector<Operation *, 4> combinerOps;
            matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
                           combinerOps);
            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
            clonedReductionOp->setOperand(0, inputs[0]);
            clonedReductionOp->setOperand(1, inputs[1]);
            linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0));
          });

      mergeOperations.push_back(reduction);
      replacements.push_back(reduction->getResult(0));
    }
    return MergeResult{mergeOperations, replacements};
  }
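  /// Return the offsets/sizes at which the partial result for `resultNumber`
  /// is inserted into the partial-reduction init tensor.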
  LogicalResult getPartialResultTilePosition(
      Operation *op, OpBuilder &b, unsigned resultNumber,
      ReductionTilingStrategy tilingStrategy, ArrayRef<OpFoldResult> offsets,
      ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
      ArrayRef<OpFoldResult> splitReductionIvs,
      SmallVector<OpFoldResult> &resultOffsets,
      SmallVector<OpFoldResult> &resultSizes) const {
    auto linalgOp = cast<LinalgOp>(op);
    SmallVector<AffineMap> partialReductionMaps =
        getPartialResultAffineMaps(linalgOp, reductionDims);
    InitSliceInfo sliceInfo = getInitSliceInfo(
        b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
        splitReductionIvs, partialReductionMaps[resultNumber]);
    std::swap(resultOffsets, sliceInfo.offsets);
    std::swap(resultSizes, sliceInfo.sizes);
    return success();
  }
};
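/// Utility to build the iteration domain (zero offsets, unit strides, sizes
/// taken from the reified result shape) shared by the pack and unpack models.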
template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard g(builder);
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)reifyResultShapes(builder, op, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}
static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
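/// External model implementation of TilingInterface for linalg.pack.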
struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {

  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on interchanged dimensions. Undo the interchange
    // to map the offsets/sizes onto the original input dimensions.
    SmallVector<OpFoldResult> origOffsets(offsets);
    SmallVector<OpFoldResult> origSizes(sizes);
    int64_t inputRank = packOp.getSourceRank();
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // If the data dimension is tiled, the i-th index is the product of
        // offset_i and tile_i, and the i-th size is the product of sizes_i
        // and tile_i.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      if (packOp.getPaddingValue()) {
        OpFoldResult dimSize =
            tensor::getMixedSize(b, loc, packOp.getSource(), dim);
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }
    SmallVector<Value> tiledOperands;
    auto oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);
    auto sourceSlice = tensor::ExtractSliceOp::create(
        b, loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, /*resultNumber=*/0, offsets, sizes,
                                     outputOffsets, outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = tensor::ExtractSliceOp::create(
        b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);
    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    // ... (creation of the tiled pack op from `tiledOperands` and assembly of
    // the TilingResult is elided here) ...
  }
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The result tile keeps the given offsets/sizes for the outer dims and
    // covers the trailing inner data-tile dims entirely.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getIndexAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)reifyResultShapes(b, packOp, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);

    return success();
  }
  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();

    // The tiled implementation can only be generated when the offsets for the
    // inner tile dims are zero and their sizes match the static tile sizes.
    for (auto offset : offsets.take_back(numTiles))
      if (!isZeroInteger(offset))
        return failure();

    for (auto iter :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();

    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
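  /// Method to return the position of the iteration domain tile computed from
  /// a tile of the source operand. Only the source operand (operand 0) is
  /// supported.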
  LogicalResult getIterationDomainTileFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
      LLVM_DEBUG(
          { llvm::dbgs() << "unsupported operands for consumer fusion"; });
      return failure();
    }
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();
    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<int64_t> outerShapeWithoutTranspose(
        packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
    if (!packOp.getOuterDimsPerm().empty()) {
      applyPermutationToVector(
          outerShapeWithoutTranspose,
          invertPermutationVector(packOp.getOuterDimsPerm()));
    }
    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        FailureOr<int64_t> cstTileSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);

        // If a dimension is not tiled, it is always valid to fuse the pack op,
        // even when the op has padding semantics, because the entire dimension
        // is fused.
        int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
        int64_t destDimSize = outerShapeWithoutTranspose[dim];
        bool isTiled = failed(cstTileSize) ||
                       ShapedType::isDynamic(srcDimSize) ||
                       cstTileSize.value() < srcDimSize;
        if (!isTiled) {
          outerDimOffsets.push_back(offsets[dim]);
          if (ShapedType::isStatic(destDimSize)) {
            outerDimSizes.push_back(b.getIndexAttr(destDimSize));
          } else {
            outerDimSizes.push_back(
                b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
          }
          continue;
        }

        // ... (checks rejecting cases that would require padding are
        // elided) ...

        // Fusing the pack op as a consumer currently expects perfect tiling:
        // the tile size taken from the operand must be divisible by the inner
        // tile size.
        if (failed(cstTileSize) || !cstInnerSize ||
            *cstTileSize % *cstInnerSize != 0)
          return failure();
        using AV = affine::AffineValueExpr;
        affine::AffineBuilder ab(b, loc);
        AffineExpr dim0, sym;
        bindDims(b.getContext(), dim0);
        bindSymbols(b.getContext(), sym);
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }
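  /// Method to generate the tiled implementation of linalg.pack from a tile of
  /// the source operand.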
  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
      LLVM_DEBUG(
          { llvm::dbgs() << "unhandled operands for consumer fusion"; });
      return failure();
    }

    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();
    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);

    int64_t inputRank = packOp.getSourceRank();
    auto oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto sourceSlice = tensor::ExtractSliceOp::create(
        b, loc, packOp.getSource(), offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);
    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
    if (failed(getIterationDomainTileFromOperandTiles(
            op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
            outerDimSizes)))
      return failure();

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, /*resultNumber=*/0, outerDimOffsets,
                                     outerDimSizes, outputOffsets,
                                     outputSizes)))
      return failure();

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = tensor::ExtractSliceOp::create(
        b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);
    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = PackOp::create(
        b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs());

    return TilingResult{
        {tiledPackOp},
        SmallVector<Value>(tiledPackOp->getResults()),
        llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
  }
};
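/// Per-dimension information for tiling linalg.unpack: how a tile of the
/// destination maps onto offsets/sizes in the (packed) source.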
struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the unpacked data dimensions: forward the
  // offset and size unchanged.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }
  Location loc = unpackOp.getLoc();
  using AV = affine::AffineValueExpr;
  affine::AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);

  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
      presburger::BoundType::UB, tileSize,
      /*stopCondition=*/nullptr, /*closedUB=*/true);
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;

    // If the tiling size equals the inner tile size, the outer dims are
    // always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }
  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;

    // The ceilDiv is needed here because the tile size may not divide the
    // dimension size evenly, e.g., for the last tile.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }
  // The unaligned case: compute the first and last source coordinates touched
  // by the tile via div/mod against the inner tile size.
  affine::DivModValue firstCoord = affine::getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  affine::DivModValue lastCoord = affine::getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // ... (computation of info.destExpandedSize is elided) ...
  return info;
}
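/// External model implementation of TilingInterface for linalg.unpack.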
struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    // The perfect tiling case indicates that the tiling sizes are multiples of
    // the inner_tile_size; no extra data beyond the result tile is read then.
    bool isPerfectTilingCase = true;
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }
    // The tiling is applied on destination dimensions; apply the interchange
    // on source dimensions if outer_dims_perm is set.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    Attribute zeroAttr = b.getIndexAttr(0);
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    SmallVector<Operation *> generatedSlices;
    tensor::ExtractSliceOp sliceSource = tensor::ExtractSliceOp::create(
        b, loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
        sliceSrcStrides);
    generatedSlices.push_back(sliceSource);
    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      auto destSliceOp = tensor::ExtractSliceOp::create(
          b, loc, unpackOp.getDest(), offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = tensor::EmptyOp::create(
          b, loc, destExpandedSizes, unpackOp.getDestType().getElementType());
    }

    SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);
    Operation *tiledUnpackOp = UnPackOp::create(
        b, loc, TypeRange{sliceDest.getType()}, tiledOperands);

    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp},
                          {tiledUnpackOp->getResult(0)},
                          generatedSlices};

    auto extractSlice = tensor::ExtractSliceOp::create(
        b, loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes,
        destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<TilingResult>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
  LogicalResult getIterationDomainTileFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes,
      SmallVectorImpl<OpFoldResult> &resultOffsets,
      SmallVectorImpl<OpFoldResult> &resultSizes) const {
    if (operandNumbers.size() != 1) {
      LLVM_DEBUG({ llvm::dbgs() << "unable to handle multiple operands"; });
      return failure();
    }
    auto unPackOp = cast<UnPackOp>(op);
    Location loc = unPackOp.getLoc();
    unsigned operandNumber = operandNumbers[0];
    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);

    // If the operand tile is the dest, the iteration domain is the tile
    // itself.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }

    // The operand is the source: drop the trailing inner-tile dimensions.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);

    // The tiling is applied on interchanged dimensions. Undo the interchange
    // to map the offsets/sizes onto the original output dimensions.
    int64_t outputRank = unPackOp.getDestRank();
    ReifiedRankedShapedTypeDims reifiedReturnShapes;
    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
      return failure();
    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
    SmallVector<OpFoldResult> origOffsets(destOffsets);
    SmallVector<OpFoldResult> origSizes(destSizes);
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      using AV = affine::AffineValueExpr;
      affine::AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym0;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym0);
      if (dimAndTileMapping.count(dim)) {
        // The dimension is tiled: scale the tile by the inner tile size and
        // clamp the size to the bounds of the result.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }
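  /// Method to generate the tiled implementation of linalg.unpack from a tile
  /// of the source operand; fusion is only supported when the inner tile
  /// dimensions are not tiled.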
  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
    if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
      LLVM_DEBUG({ llvm::dbgs() << "unhandled operands for consumer fusion"; });
      return failure();
    }
    auto unPackOp = cast<UnPackOp>(op);
    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
    ArrayRef<OpFoldResult> sizes(allSizes[0]);

    // The tensor.unpack op is fusible only when the inner dims are not tiled.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto iter :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(std::get<0>(iter), std::get<1>(iter)))
        return failure();
    }

    Location loc = unPackOp.getLoc();

    // Fetch offsets/sizes for creating the slice of the dest operand of the
    // unpack op.
    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getIterationDomainTileFromOperandTiles(
            op, b, operandNumbers, allOffsets, allSizes, outputOffsets,
            outputSizes)))
      return failure();

    auto oneAttr = b.getIndexAttr(1);
    int64_t outputRank = unPackOp.getDestRank();
    SmallVector<OpFoldResult> strides(outputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    auto extractDestSlice = tensor::ExtractSliceOp::create(
        b, loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    auto extractSourceSlice = tensor::ExtractSliceOp::create(
        b, loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);
    Operation *tiledUnPackOp =
        UnPackOp::create(b, loc, TypeRange{extractDestSlice.getType()},
                         tiledOperands);

    return TilingResult{{tiledUnPackOp},
                        {tiledUnPackOp->getResult(0)},
                        llvm::to_vector(ArrayRef<Operation *>{
                            extractSourceSlice, extractDestSlice})};
  }
};
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
      *ctx);
}

/// Variadic helper function.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  (registerOne<OpTypes>(ctx), ...);
}
void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    registerOne<linalg::GenericOp>(ctx);
    linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
    linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
    registerAll<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(ctx);
  });
}

void mlir::linalg::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
    linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
    linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
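// Example of how a client typically wires these registrations up (a sketch;
// the exact pipeline setup is client-specific, only the two registration
// entry points above are defined in this file):
//
//   DialectRegistry registry;
//   linalg::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);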