27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/Support/Debug.h"
31#define DEBUG_TYPE "linalg-tiling-interface-impl"
50 Value v = affine::AffineApplyOp::create(
b, loc, m, ivs);
60 Block *body = linalgOp.getBlock();
64 if (
auto indexOp = dyn_cast<IndexOp>(&op)) {
65 map.
map(indexOp.getResult(), ivs[indexOp.getDim()]);
73 for (
const auto &operand : llvm::enumerate(terminator->
getOperands())) {
75 OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
77 b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
78 memref::StoreOp::create(
b, loc, toStore,
79 linalgOp.getDpsInitOperand(operand.index())->get(),
95template <
typename LinalgOpTy>
96struct LinalgOpTilingInterface
97 :
public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
100 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op)
const {
101 LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
102 return concreteOp.getIteratorTypesArray();
106 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &
b)
const {
107 OpBuilder::InsertionGuard g(
b);
108 b.setInsertionPoint(op);
109 Location loc = op->
getLoc();
110 LinalgOp linalgOp = cast<LinalgOp>(op);
111 SmallVector<OpFoldResult> allShapesSizes =
112 linalgOp.createFlatListOfOperandDims(
b, loc);
113 AffineMap map = linalgOp.getShapesToLoopsMap();
115 return llvm::map_to_vector(map.
getResults(), [&](AffineExpr loopExpr) {
116 OpFoldResult ofr = affine::makeComposedFoldedAffineApply(b, loc, loopExpr,
118 return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
123 FailureOr<TilingResult>
130 LinalgOp linalgOp = cast<LinalgOp>(op);
133 b, loc, linalgOp, valuesToTile, offsets, sizes, {},
true);
135 llvm::make_filter_range(
137 [](
Value v) ->
bool {
138 return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
146 Operation *tiledOp =
clone(
b, linalgOp, resultTensorTypes, tiledOperands);
157 getMappedOffsetAndSize(LinalgOp linalgOp,
OpBuilder &
b,
165 for (
auto [indexingMap, offsets, sizes] :
166 llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
167 for (
auto [resultExpr, offset, size] :
168 llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
169 auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
172 unsigned position = dimExpr.getPosition();
173 auto it = mappedOffsets.find(position);
174 if (it != mappedOffsets.end()) {
177 if (seenOffset != offset || seenSize != size) {
179 llvm::dbgs() <<
"inconsistent iteration space mapping from "
180 "offsets/sizes of operands/results";
185 mappedOffsets[position] = offset;
186 mappedSizes[position] = size;
194 cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(
b);
195 mappedOffsetsVec.resize(iterationDomain.size());
196 mappedSizesVec.resize(iterationDomain.size());
197 for (
auto [
index, domain] : llvm::enumerate(iterationDomain)) {
198 auto it = mappedOffsets.find(
index);
199 if (it != mappedOffsets.end()) {
200 mappedOffsetsVec[
index] = it->second;
201 mappedSizesVec[
index] = mappedSizes.lookup(
index);
204 mappedOffsetsVec[
index] = domain.offset;
205 mappedSizesVec[
index] = domain.size;
212 LogicalResult getIterationDomainTileFromOperandTiles(
218 auto linalgOp = cast<LinalgOp>(op);
221 llvm::map_to_vector(operandNumbers, [&](
unsigned operandNumber) {
222 OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
223 return linalgOp.getMatchingIndexingMap(&opOperand);
225 if (
failed(getMappedOffsetAndSize(linalgOp,
b, indexingMaps, allOffsets,
226 allSizes, iterDomainOffsets,
242 LinalgOp linalgOp = cast<LinalgOp>(op);
251 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
253 b, loc, outOperand->get(), sizes,
254 linalgOp.getMatchingIndexingMap(outOperand), offsets,
255 {}, subShapeSizes,
true);
256 resultOffsets = sliceParams.
offsets;
257 resultSizes = sliceParams.
sizes;
261 LogicalResult getIterationDomainTileFromResultTile(
266 auto linalgOp = cast<LinalgOp>(op);
273 linalgOp.getIndexingMapMatchingResult(op->
getResult(resultNumber));
276 "unhandled tiled implementation generation when result is not "
277 "accessed using a permuted projection");
283 getMappedOffsetAndSize(linalgOp,
b, indexingMap, {allOffsets},
284 {allSizes}, iterDomainOffsets, iterDomainSizes);
286 assert(succeeded(status) &&
"unexpected error in offset calculation");
290 FailureOr<TilingResult>
295 if (
failed(getIterationDomainTileFromResultTile(
296 op,
b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
299 auto tilingInterfaceOp = cast<TilingInterface>(op);
300 FailureOr<TilingResult> tilingResult =
301 tilingInterfaceOp.getTiledImplementation(
b, mappedOffsets, mappedSizes);
306 if (tilingResult->tiledOps.size() != 1)
307 return op->
emitOpError(
"failed to generate tiled implementation");
310 tilingResult->tiledOps,
312 tilingResult->generatedSlices};
317 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
322 if (
failed(getIterationDomainTileFromOperandTiles(
323 op,
b, operandNumbers, allOffsets, allSizes, mappedOffsets,
333 auto linalgOp = cast<LinalgOp>(op);
334 if (!linalgOp.hasPureBufferSemantics())
335 return op->
emitOpError(
"expected operation to have buffer semantics");
338 indexedValues.reserve(linalgOp->getNumOperands());
342 for (
OpOperand &operand : linalgOp->getOpOperands()) {
343 if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
344 indexedValues.push_back(
nullptr);
347 if (linalgOp.isScalar(&operand)) {
348 indexedValues.push_back(operand.get());
352 builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
354 memref::LoadOp::create(builder, linalgOpLoc, operand.get(),
indices);
355 indexedValues.push_back(
load);
362 bool isOpFusableWithConsumerSlice(
Operation *op,
unsigned resultNumber,
369 bool isOpFusableWithProducerSlices(
374 auto linalgOp = cast<LinalgOp>(op);
376 llvm::map_to_vector(operandNumbers, [&](
unsigned operandNumber) {
377 OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
378 return linalgOp.getMatchingIndexingMap(&opOperand);
383 return succeeded(getMappedOffsetAndSize(linalgOp,
b, indexingMaps,
384 allOffsets, allSizes, mappedOffsets,
396 for (
auto [
index, reductionDim] : llvm::enumerate(reductionDims)) {
397 if (reductionDim == value) {
409getPartialResultAffineMaps(LinalgOp linalgOp,
411 auto partialReductionMaps = llvm::map_to_vector(
412 linalgOp.getDpsInitsMutable(), [&](
OpOperand &opOperand) {
413 AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
414 for (auto redPos : reductionDims) {
416 map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
417 map.getNumResults());
421 return partialReductionMaps;
// Bundle of slice parameters computed for the accumulator ("init") operand of
// a partial reduction. Produced by the getInitSliceInfo* helpers below and
// consumed when creating the tensor.extract_slice of the init operand (see the
// ExtractSliceOp creation in tileToPartialReduction and the offset/size
// propagation in getPartialResultTilePosition).
424struct InitSliceInfo {
// Static shape used to build the RankedTensorType of the extracted slice
// (combined with the init operand's element type and encoding).
 425 SmallVector<int64_t> resultShape;
// Per-dimension offsets of the slice into the init operand.
 426 SmallVector<OpFoldResult> offsets;
// Per-dimension sizes of the slice.
 427 SmallVector<OpFoldResult> sizes;
// Per-dimension strides of the slice (unit strides in the visible callers).
 428 SmallVector<OpFoldResult> strides;
434static InitSliceInfo getInitSliceInfoForOuterReduction(
441 Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
442 Attribute one = IntegerAttr::get(IndexType::get(context), 1);
444 for (
auto [resultIdx, dimExpr] :
445 llvm::enumerate(partialReductionMap.
getResults())) {
446 if (isa<AffineConstantExpr>(dimExpr)) {
449 initOffsets.push_back(zero);
450 initSizes.push_back(initOperandShape[resultIdx]);
453 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
454 if (reductionDims.contains(dim)) {
455 initOffsets.push_back(zero);
457 initOffsets.push_back(offsets[dim]);
459 initSizes.push_back(sizes[dim]);
463 return {resultShape, initOffsets, initSizes, initStrides};
469static InitSliceInfo getInitSliceInfoForOuterParallel(
476 Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
477 Attribute one = IntegerAttr::get(IndexType::get(context), 1);
480 for (
auto [resultIdx, dimExpr] :
481 llvm::enumerate(partialReductionMap.
getResults())) {
482 if (isa<AffineConstantExpr>(dimExpr)) {
485 initOffsets.push_back(zero);
486 initSizes.push_back(initOperandShape[resultIdx]);
487 resultShape.push_back(initOperandShape[resultIdx]);
490 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
491 if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
492 initOffsets.push_back(splitReductionIvs[dimPos.value()]);
493 initSizes.push_back(one);
495 initOffsets.push_back(offsets[dim]);
496 initSizes.push_back(sizes[dim]);
497 resultShape.push_back(sizes[dim]);
502 return {staticShapes, initOffsets, initSizes, initStrides};
507static InitSliceInfo getInitSliceInfo(
MLIRContext *context,
516 return getInitSliceInfoForOuterReduction(
517 context, offsets, sizes, reductionDims, splitReductionIvs,
518 partialReductionMap, initOperandShape);
521 "unexpected ReductionTilingStrategy");
522 return getInitSliceInfoForOuterParallel(
523 context, offsets, sizes, reductionDims, splitReductionIvs,
524 partialReductionMap, initOperandShape);
529template <
typename LinalgOpTy>
530struct LinalgOpPartialReductionInterface
531 :
public PartialReductionOpInterface::ExternalModel<
532 LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
533 FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
534 Operation *op, OpBuilder &
b, Location loc, ArrayRef<OpFoldResult> sizes,
536 auto linalgOp = cast<LinalgOp>(op);
538 OpBuilder::InsertionGuard guard(
b);
539 if (linalgOp.hasPureBufferSemantics())
540 return op->
emitOpError(
"expected operation to have tensor semantics");
542 SmallVector<AffineMap> partialResultMaps =
543 getPartialResultAffineMaps(linalgOp, reductionDims);
545 SmallVector<Value> inits;
546 for (
auto [initIdx,
result, partialMap] :
547 llvm::enumerate(linalgOp->getResults(), partialResultMaps)) {
548 SmallVector<Operation *, 4> combinerOps;
551 combinerOps.size() != 1)
552 return op->
emitOpError(
"Failed to anaysis the reduction operation.");
554 Operation *reductionOp = combinerOps[0];
555 std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
556 if (!identity.has_value())
558 "Failed to get an identity value for the reduction operation.");
561 SmallVector<OpFoldResult> partialResultShape;
562 Value initValue = linalgOp.getDpsInits()[initIdx];
563 SmallVector<OpFoldResult> initShape =
565 for (
auto [resultIdx, dimExpr] :
566 llvm::enumerate(partialMap.getResults())) {
567 if (isa<AffineConstantExpr>(dimExpr)) {
570 partialResultShape.push_back(initShape[resultIdx]);
573 auto dim = cast<AffineDimExpr>(dimExpr);
574 partialResultShape.push_back(sizes[dim.getPosition()]);
579 tensor::EmptyOp::create(
b, loc, partialResultShape, elType);
580 Value constantOp = arith::ConstantOp::create(
b, loc, *identity);
581 auto identityTensor =
582 linalg::FillOp::create(
b, loc, constantOp, emptyTensor);
583 inits.push_back(identityTensor.getResult(0));
589 FailureOr<TilingResult>
590 tileToPartialReduction(Operation *op, OpBuilder &
b, Location loc,
592 ValueRange init, ArrayRef<OpFoldResult> offsets,
593 ArrayRef<OpFoldResult> sizes,
595 ArrayRef<OpFoldResult> splitReductionIvs)
const {
596 OpBuilder::InsertionGuard guard(
b);
597 auto linalgOp = cast<LinalgOp>(op);
599 SmallVector<AffineMap> partialReductionMaps =
600 getPartialResultAffineMaps(linalgOp, reductionDims);
604 SmallVector<AffineMap> newInitMaps;
605 if (tilingStrategy ==
606 ReductionTilingStrategy::PartialReductionOuterReduction) {
607 newInitMaps = llvm::to_vector(partialReductionMaps);
609 newInitMaps = llvm::map_to_vector(
610 linalgOp.getDpsInitsMutable(), [&](OpOperand &opOperand) {
611 return linalgOp.getMatchingIndexingMap(&opOperand);
617 b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {},
true);
618 SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
619 llvm::make_filter_range(
620 tiledInputs, [](Value v) ->
bool {
return v.
getDefiningOp(); }),
624 SmallVector<Value, 1> tiledInits;
625 for (
auto [partialReductionMap, valueToTile, initOperandValue] :
626 llvm::zip_equal(partialReductionMaps, init, linalgOp.getDpsInits())) {
629 SmallVector<OpFoldResult> initOperandShape =
631 InitSliceInfo sliceInfo = getInitSliceInfo(
632 b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
633 splitReductionIvs, partialReductionMap, initOperandShape);
634 auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
636 sliceInfo.resultShape, valueToTileType.getElementType(),
637 valueToTileType.getEncoding());
638 auto sliceOp = tensor::ExtractSliceOp::create(
640 sliceInfo.sizes, sliceInfo.strides);
641 tiledInits.push_back(sliceOp.getResult());
642 generatedSlices.push_back(sliceOp);
646 SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
647 for (
auto [initOperand, newInitMap] :
648 llvm::zip_equal(linalgOp.getDpsInitsMutable(), newInitMaps)) {
649 int mapIdx = linalgOp.getIndexingMapIndex(&initOperand);
650 newMaps[mapIdx] = newInitMap;
654 SmallVector<utils::IteratorType> newIteratorTypes =
655 linalgOp.getIteratorTypesArray();
656 if (tilingStrategy ==
657 ReductionTilingStrategy::PartialReductionOuterReduction) {
658 for (
int dim : reductionDims)
659 newIteratorTypes[dim] = utils::IteratorType::parallel;
663 Operation *partialReductionOp;
664 auto resultTypes =
ValueRange(tiledInits).getTypes();
665 if (tilingStrategy ==
666 ReductionTilingStrategy::PartialReductionOuterReduction) {
667 auto genericOp = GenericOp::create(
b, loc, resultTypes, tiledInputs,
668 tiledInits, newMaps, newIteratorTypes);
671 genericOp.getRegion().begin(), mapping);
673 partialReductionOp = genericOp.getOperation();
675 SmallVector<Value> operands = std::move(tiledInputs);
676 llvm::append_range(operands, tiledInits);
677 partialReductionOp =
mlir::clone(
b, op, resultTypes, operands);
681 {partialReductionOp},
682 llvm::map_to_vector(partialReductionOp->
getResults(),
683 [](OpResult r) -> Value { return r; }),
687 FailureOr<MergeResult>
688 mergeReductions(Operation *op, OpBuilder &
b, Location loc,
691 auto linalgOp = cast<LinalgOp>(op);
692 SmallVector<AffineMap> partialReductionMaps =
693 getPartialResultAffineMaps(linalgOp, reductionDims);
696 SmallVector<Operation *> mergeOperations;
697 SmallVector<Value> replacements;
698 for (
auto [idx, init, partialResult, partialMap] : llvm::enumerate(
699 linalgOp.getDpsInits(), partialReduce, partialReductionMaps)) {
700 unsigned initIdx = idx;
705 SmallVector<int64_t> partialReductionDims;
706 for (
auto [resultNum, dimExpr] :
707 llvm::enumerate(partialMap.getResults())) {
708 if (isa<AffineConstantExpr>(dimExpr))
710 unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
711 if (llvm::is_contained(reductionDims, dim)) {
712 partialReductionDims.push_back(resultNum);
716 auto reduction = linalg::ReduceOp::create(
717 b, loc, partialResult, init, partialReductionDims,
718 [&linalgOp, &initIdx](OpBuilder &
b, Location loc,
ValueRange inputs) {
720 SmallVector<Operation *, 4> combinerOps;
723 Operation *clonedReductionOp =
b.clone(*combinerOps[0]);
727 linalg::YieldOp::create(
b, loc, clonedReductionOp->
getResult(0));
730 mergeOperations.push_back(reduction);
731 replacements.push_back(reduction->getResult(0));
734 return MergeResult{mergeOperations, replacements};
737 LogicalResult getPartialResultTilePosition(
738 Operation *op, OpBuilder &
b,
unsigned resultNumber,
741 ArrayRef<OpFoldResult> splitReductionIvs,
742 SmallVector<OpFoldResult> &resultOffsets,
743 SmallVector<OpFoldResult> &resultSizes)
const {
744 auto linalgOp = cast<LinalgOp>(op);
745 SmallVector<AffineMap> partialReductionMaps =
746 getPartialResultAffineMaps(linalgOp, reductionDims);
749 Value initOperandValue = linalgOp.getDpsInits()[resultNumber];
750 Location loc = op->
getLoc();
751 SmallVector<OpFoldResult> initOperandShape =
753 InitSliceInfo sliceInfo =
754 getInitSliceInfo(
b.getContext(), tilingStrategy, offsets, sizes,
755 reductionDims, splitReductionIvs,
756 partialReductionMaps[resultNumber], initOperandShape);
757 std::swap(resultOffsets, sliceInfo.offsets);
758 std::swap(resultSizes, sliceInfo.sizes);
764template <
typename OpTy>
767 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
768 "applies to only pack or unpack operations");
770 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
775 (
void)op.reifyResultShapes(builder, resultShape);
777 for (
auto dim : llvm::seq<int64_t>(0, rank)) {
778 loopBounds[dim].offset = zero;
779 loopBounds[dim].stride = one;
780 loopBounds[dim].size = resultShape[0][dim];
788 if (permutation.empty())
800 interchangeVector.reserve(dimsPos.size());
809 for (
int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++)
810 dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
814 for (
int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) {
815 if (dimsAndPosMapping.count(dimsIdx))
816 interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
818 return interchangeVector;
838 for (
auto [idx, val] : llvm::enumerate(interchangeVector))
839 vec[idx + offset] = elements[val + offset];
845static void generatePackOpScalarImplementationBody(PackOp packOp,
860 computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getSourceRank());
861 interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
862 packOp.getSourceRank());
863 if (!dimsToOuterBlock.empty()) {
865 computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getSourceRank());
867 interchange<Value>(interchangedIvs, interchangeVector, 0);
870 packOp.getDimAndTileMapping();
872 size_t pointLoopsOffset = 0;
873 int64_t sourceRank = packOp.getSourceRank();
874 for (
auto dim : llvm::seq<int64_t>(0, sourceRank)) {
875 if (dimAndTileMapping.contains(dim)) {
880 builder, loc, i *
tile +
j,
882 interchangedIvs[dim],
883 interchangedIvs[pointLoopsOffset + packOp.getSourceRank()],
884 dimAndTileMapping[dim]});
885 sourceIndices.push_back(sourceIndex);
888 sourceIndices.push_back(interchangedIvs[dim]);
892 auto createLoad = [&]() ->
Value {
893 return memref::LoadOp::create(
894 builder, loc, packOp.getSource(),
898 if (
auto paddingValue = packOp.getPaddingValue()) {
901 for (
auto dim : llvm::seq<int64_t>(0, sourceRank)) {
904 Value cond = arithBuilder.slt(
908 scalar = scf::IfOp::create(
911 scf::YieldOp::create(
b, l, createLoad());
915 scf::YieldOp::create(
b, l, paddingValue);
919 scalar = createLoad();
922 memref::StoreOp::create(builder, loc, scalar, packOp.getDest(), ivs);
926 :
public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
928 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op)
const {
932 auto packOp = cast<PackOp>(op);
933 SmallVector<utils::IteratorType> iteratorTypes(
934 packOp.getSourceRank(), utils::IteratorType::parallel);
935 return iteratorTypes;
938 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &
b)
const {
939 return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op),
b);
942 FailureOr<TilingResult>
944 ArrayRef<OpFoldResult> offsets,
945 ArrayRef<OpFoldResult> sizes)
const {
946 auto packOp = cast<PackOp>(op);
948 if (!packOp.hasPureTensorSemantics())
951 Location loc = packOp.getLoc();
955 int64_t inputRank = packOp.getSourceRank();
956 SmallVector<OpFoldResult> origOffsets(offsets);
957 SmallVector<OpFoldResult> origSizes(sizes);
958 applyPermToRange(origOffsets, origSizes,
962 packOp.getDimAndTileMapping();
963 SmallVector<OpFoldResult> srcDimValues =
965 SmallVector<OpFoldResult> inputIndices, inputSizes;
966 for (
auto dim : llvm::seq<int64_t>(0, inputRank)) {
967 using AV = affine::AffineValueExpr;
968 affine::AffineBuilder ab(
b, loc);
969 AffineExpr dim0, dim1, sym;
972 if (dimAndTileMapping.count(dim)) {
976 auto avOffset = AV(dim0).bind(origOffsets[dim]);
977 auto avSize = AV(dim0).bind(origSizes[dim]);
978 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
979 inputIndices.push_back(ab.mul(avOffset, avTileSize));
980 inputSizes.push_back(ab.mul(avSize, avTileSize));
982 inputIndices.push_back(origOffsets[dim]);
983 inputSizes.push_back(origSizes[dim]);
987 if (packOp.getPaddingValue()) {
988 OpFoldResult dimSize = srcDimValues[dim];
989 auto avDimSize = AV(dim0).bind(dimSize);
990 auto avInputIdx = AV(dim1).bind(inputIndices.back());
992 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
996 auto oneAttr =
b.getI64IntegerAttr(1);
997 SmallVector<OpFoldResult> strides(inputRank, oneAttr);
999 SmallVector<Value> tiledOperands;
1000 auto sourceSlice = tensor::ExtractSliceOp::create(
1001 b, loc, packOp.getSource(), inputIndices, inputSizes, strides);
1002 tiledOperands.push_back(sourceSlice);
1004 SmallVector<OpFoldResult> outputOffsets, outputSizes;
1009 strides.append(packOp.getDestRank() - inputRank, oneAttr);
1010 auto outSlice = tensor::ExtractSliceOp::create(
1011 b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
1012 tiledOperands.push_back(outSlice);
1014 if (
auto val = packOp.getPaddingValue())
1015 tiledOperands.push_back(val);
1016 for (
auto tile : packOp.getInnerTiles())
1017 tiledOperands.push_back(
tile);
1019 Operation *tiledPackOp = PackOp::create(
1022 return TilingResult{
1024 SmallVector<Value>(tiledPackOp->
getResults()),
1025 llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
1030 ArrayRef<OpFoldResult> offsets,
1031 ArrayRef<OpFoldResult> sizes,
1032 SmallVector<OpFoldResult> &resultOffsets,
1033 SmallVector<OpFoldResult> &resultSizes)
const {
1038 auto packOp = cast<PackOp>(op);
1039 int64_t inputRank = packOp.getSourceRank();
1040 int64_t outputRank = packOp.getDestRank();
1041 auto zeroAttr =
b.getI64IntegerAttr(0);
1042 resultOffsets.assign(offsets.begin(), offsets.end());
1043 resultOffsets.append(outputRank - inputRank, zeroAttr);
1047 resultSizes.assign(sizes.begin(), sizes.end());
1048 for (
auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
1049 resultSizes.push_back(outputShape[0][dataTileDim]);
1054 FailureOr<TilingResult>
1055 generateResultTileValue(Operation *op, OpBuilder &
b,
unsigned resultNumber,
1056 ArrayRef<OpFoldResult> offsets,
1057 ArrayRef<OpFoldResult> sizes)
const {
1058 auto packOp = cast<PackOp>(op);
1059 int64_t numTiles = packOp.getInnerDimsPos().size();
1064 for (
auto offset : offsets.take_back(numTiles))
1069 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
1074 op,
b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
1075 if (
failed(tilingResult))
1077 return tilingResult.value();
1080 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
1083 auto packOp = cast<PackOp>(op);
1084 assert(packOp.hasPureBufferSemantics() &&
1085 "expected operation to have buffer semantics");
1086 OpBuilder::InsertionGuard g(builder);
1089 SmallVector<Value> ivVec(ivs);
1092 SmallVector<OpFoldResult> outputShape;
1093 Value dest = packOp.getDest();
1094 for (
auto dim : llvm::seq<int64_t>(0, packOp.getDestRank()))
1103 for (
auto dataTileDim : llvm::seq<unsigned>(packOp.getSourceRank(),
1104 packOp.getDestRank() - 1)) {
1106 outputShape[dataTileDim]);
1107 scf::ForOp loop = scf::ForOp::create(builder, loc, zero, ub, one);
1109 ivVec.push_back(loop.getInductionVar());
1116 [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
1118 ivVec.push_back(iv);
1119 generatePackOpScalarImplementationBody(packOp, bodyBuilder, bodyLoc,
1121 scf::YieldOp::create(bodyBuilder, bodyLoc);
1129 LogicalResult getIterationDomainTileFromOperandTiles(
1130 Operation *op, OpBuilder &
b, ArrayRef<unsigned> operandNumbers,
1131 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1132 ArrayRef<SmallVector<OpFoldResult>> allSizes,
1133 SmallVectorImpl<OpFoldResult> &resultOffsets,
1134 SmallVectorImpl<OpFoldResult> &resultSizes)
const {
1135 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1137 { llvm::dbgs() <<
"unsupported operands for consumer fusion"; });
1141 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1142 ArrayRef<OpFoldResult> sizes(allSizes[0]);
1143 auto packOp = cast<PackOp>(op);
1144 Location loc = packOp.getLoc();
1145 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
1147 packOp.getDimAndTileMapping();
1148 SmallVector<int64_t> outerShapeWithoutTranspose(
1149 packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
1150 if (!packOp.getOuterDimsPerm().empty()) {
1152 outerShapeWithoutTranspose,
1155 for (
auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
1156 if (dimAndTileMapping.count(dim)) {
1157 FailureOr<int64_t> cstTileSize =
1159 presburger::BoundType::UB, sizes[dim],
1161 std::optional<int64_t> cstInnerSize =
1171 int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
1172 int64_t destDimSize = outerShapeWithoutTranspose[dim];
1174 ShapedType::isDynamic(srcDimSize) ||
1175 cstTileSize.value() < srcDimSize;
1177 outerDimOffsets.push_back(offsets[dim]);
1178 if (ShapedType::isStatic(destDimSize)) {
1179 outerDimSizes.push_back(
b.getIndexAttr(destDimSize));
1181 outerDimSizes.push_back(
1182 b.createOrFold<tensor::DimOp>(loc, packOp.getDest(), dim));
1201 if ((
failed(cstTileSize) || !cstInnerSize ||
1202 *cstTileSize % *cstInnerSize != 0))
1205 using AV = affine::AffineValueExpr;
1206 affine::AffineBuilder ab(
b, loc);
1207 AffineExpr dim0, sym;
1210 auto avOffset = AV(dim0).bind(offsets[dim]);
1211 auto avSize = AV(dim0).bind(sizes[dim]);
1212 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
1213 outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
1214 outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
1216 outerDimOffsets.push_back(offsets[dim]);
1217 outerDimSizes.push_back(sizes[dim]);
1220 applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
1221 resultOffsets = outerDimOffsets;
1222 resultSizes = outerDimSizes;
1227 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
1228 Operation *op, OpBuilder &
b, ArrayRef<unsigned> operandNumbers,
1229 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1230 ArrayRef<SmallVector<OpFoldResult>> allSizes)
const {
1231 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1233 { llvm ::dbgs() <<
"unhandled operands for consumer fusion"; });
1237 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1238 ArrayRef<OpFoldResult> sizes(allSizes[0]);
1240 auto packOp = cast<PackOp>(op);
1242 if (!packOp.hasPureTensorSemantics())
1245 Location loc = packOp.getLoc();
1247 int64_t inputRank = packOp.getSourceRank();
1248 auto oneAttr =
b.getI64IntegerAttr(1);
1249 SmallVector<OpFoldResult> strides(inputRank, oneAttr);
1251 SmallVector<Value> tiledOperands;
1252 auto sourceSlice = tensor::ExtractSliceOp::create(
1253 b, loc, packOp.getSource(), offsets, sizes, strides);
1254 tiledOperands.push_back(sourceSlice);
1256 SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
1257 if (
failed(getIterationDomainTileFromOperandTiles(
1258 op,
b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
1262 SmallVector<OpFoldResult> outputOffsets, outputSizes;
1264 outputOffsets, outputSizes)))
1267 strides.append(packOp.getDestRank() - inputRank, oneAttr);
1268 auto outSlice = tensor::ExtractSliceOp::create(
1269 b, loc, packOp.getDest(), outputOffsets, outputSizes, strides);
1270 tiledOperands.push_back(outSlice);
1272 if (
auto val = packOp.getPaddingValue())
1273 tiledOperands.push_back(val);
1274 for (
auto tile : packOp.getInnerTiles())
1275 tiledOperands.push_back(
tile);
1277 Operation *tiledPackOp = PackOp::create(
1280 return TilingResult{
1282 SmallVector<Value>(tiledPackOp->
getResults()),
1283 llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
// Per-dimension tiling information for linalg.unpack, computed by
// getUnpackTileDimInfo and consumed by UnPackOpTiling::getTiledImplementation.
// NOTE(review): additional members may follow in the original file beyond the
// lines visible here — confirm against the full source.
1287struct UnpackTileDimInfo {
// True when the requested tile along this dim covers whole inner tiles (or the
// dim is not tiled at all); when true for every dim the "perfect tiling" path
// is taken and the result slice-out is skipped.
 1288 bool isAlignedToInnerTileSize;
// Offset/size of the slice to extract from the packed source along the
// corresponding outer dimension.
 1289 OpFoldResult sourceOffset;
 1290 OpFoldResult sourceSize;
// Offset into the (possibly expanded) unpacked result at which the requested
// tile starts; zero in the aligned case.
 1291 OpFoldResult resultOffset;
// Size of the expanded destination along this dim, used to size the
// tensor.empty destination in the imperfect-tiling case.
 1292 OpFoldResult destExpandedSize;
1298static UnpackTileDimInfo getUnpackTileDimInfo(
OpBuilder &
b, UnPackOp unpackOp,
1302 UnpackTileDimInfo info;
1306 unpackOp.getDimAndTileMapping();
1308 if (!dimAndTileMapping.count(tileDim)) {
1309 info.isAlignedToInnerTileSize =
true;
1310 info.sourceOffset = tileOffset;
1311 info.sourceSize = tileSize;
1312 info.resultOffset = zeroAttr;
1313 info.destExpandedSize = tileSize;
1324 OpFoldResult innerTileSize = dimAndTileMapping[tileDim];
1326 info.isAlignedToInnerTileSize =
false;
1331 if (!
failed(cstSize) && cstInnerSize) {
1332 if (*cstSize % *cstInnerSize == 0)
1333 info.isAlignedToInnerTileSize =
true;
1337 if (*cstInnerSize == *cstSize) {
1338 auto lhs = AV(dim0).bind(tileOffset);
1339 auto rhs = AV(dim1).bind(innerTileSize);
1340 info.sourceOffset = ab.floor(
lhs,
rhs);
1341 info.sourceSize = oneAttr;
1342 info.resultOffset = zeroAttr;
1343 info.destExpandedSize = tileSize;
1348 if (info.isAlignedToInnerTileSize) {
1350 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
1351 info.resultOffset = zeroAttr;
1352 info.destExpandedSize = tileSize;
1361 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
1365 affine::DivModValue firstCoord = affine::getDivMod(
1369 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
1370 affine::DivModValue lastCoord = affine::getDivMod(
1374 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
1377 OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
1378 AV(dim1).bind(firstCoord.quotient));
1380 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
1381 info.sourceOffset = firstCoord.quotient;
1382 info.resultOffset = firstCoord.remainder;
1385 info.destExpandedSize =
b.createOrFold<arith::MulIOp>(
1391struct UnPackOpTiling
1392 :
public TilingInterface::ExternalModel<UnPackOpTiling, linalg::UnPackOp> {
1394 SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op)
const {
1395 auto unpackOp = cast<UnPackOp>(op);
1396 SmallVector<utils::IteratorType> iteratorTypes(
1397 unpackOp.getDestRank(), utils::IteratorType::parallel);
1398 return iteratorTypes;
1401 SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &
b)
const {
1402 return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op),
b);
1419 FailureOr<TilingResult>
1421 ArrayRef<OpFoldResult> offsets,
1422 ArrayRef<OpFoldResult> sizes)
const {
1423 auto unpackOp = cast<UnPackOp>(op);
1425 if (!unpackOp.hasPureTensorSemantics())
1428 int64_t srcRank = unpackOp.getSourceRank();
1429 int64_t destRank = unpackOp.getDestRank();
1430 int64_t numInnerTiles = srcRank - destRank;
1431 Location loc = unpackOp.getLoc();
1436 bool isPerfectTilingCase =
true;
1437 Attribute oneAttr =
b.getIndexAttr(1);
1438 SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
1439 SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
1440 SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
1441 for (
auto dim : llvm::seq<int64_t>(0, destRank)) {
1442 UnpackTileDimInfo info =
1443 getUnpackTileDimInfo(
b, unpackOp, dim, offsets[dim], sizes[dim]);
1444 if (!info.isAlignedToInnerTileSize)
1445 isPerfectTilingCase =
false;
1446 sliceSrcIndices.push_back(info.sourceOffset);
1447 sliceSrcSizes.push_back(info.sourceSize);
1448 destExpandedSizes.push_back(info.destExpandedSize);
1449 resultOffsetsFromDest.push_back(info.resultOffset);
1454 applyPermToRange(sliceSrcIndices, sliceSrcSizes,
1455 unpackOp.getOuterDimsPerm());
1456 Attribute zeroAttr =
b.getIndexAttr(0);
1457 sliceSrcIndices.append(numInnerTiles, zeroAttr);
1458 sliceSrcSizes.append(unpackOp.getMixedTiles());
1459 sliceSrcStrides.append(numInnerTiles, oneAttr);
1460 SmallVector<Operation *> generatedSlices;
1461 tensor::ExtractSliceOp sliceSource = tensor::ExtractSliceOp::create(
1462 b, loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes,
1464 generatedSlices.push_back(sliceSource);
1466 SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
1468 if (isPerfectTilingCase) {
1469 auto destSliceOp = tensor::ExtractSliceOp::create(
1470 b, loc, unpackOp.getDest(), offsets, sizes, destStrides);
1471 sliceDest = destSliceOp;
1472 generatedSlices.push_back(destSliceOp);
1474 sliceDest = tensor::EmptyOp::create(
1475 b, loc, destExpandedSizes, unpackOp.getDestType().getElementType());
1478 SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest};
1479 for (
auto tile : unpackOp.getInnerTiles())
1480 tiledOperands.push_back(
tile);
1482 Operation *tiledUnpackOp = UnPackOp::create(
1485 if (isPerfectTilingCase)
1486 return TilingResult{{tiledUnpackOp},
1487 SmallVector<Value>(tiledUnpackOp->
getResults()),
1490 auto extractSlice = tensor::ExtractSliceOp::create(
1491 b, loc, tiledUnpackOp->
getResult(0), resultOffsetsFromDest, sizes,
1493 return TilingResult{
1494 {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
// UnPackOpTiling::getResultTilePosition (fragment — the opening of the
// signature is missing from this extraction). For unpack the iteration
// domain coincides with the destination tensor, so the result tile is
// exactly the requested offsets/sizes.
1499 ArrayRef<OpFoldResult> offsets,
1500 ArrayRef<OpFoldResult> sizes,
1501 SmallVector<OpFoldResult> &resultOffsets,
1502 SmallVector<OpFoldResult> &resultSizes)
const {
1503 resultOffsets = llvm::to_vector(offsets);
1504 resultSizes = llvm::to_vector(sizes);
// Produce the value of one result tile by tiling the op itself; presumably
// forwards to getTiledImplementation and propagates failure — TODO confirm,
// the initializer expression is missing from this extraction.
1508 FailureOr<TilingResult>
1509 generateResultTileValue(Operation *op, OpBuilder &
b,
unsigned resultNumber,
1510 ArrayRef<OpFoldResult> offsets,
1511 ArrayRef<OpFoldResult> sizes)
const {
1512 FailureOr<TilingResult> tilingResult =
// NOTE(review): the RHS of the assignment and the failure return were elided.
1514 if (
failed(tilingResult))
1516 return tilingResult.value();
// Scalar (buffer-semantics) implementation of unpack at one iteration point:
// each destination induction variable that corresponds to a tiled dim is
// split (div/mod by the tile size) into an outer index and an inner
// "point-loop" index; the indices are permuted to match the packed source
// layout, one element is loaded from the source and stored to the dest.
// NOTE(review): the `ValueRange ivs` parameter line, the `Value scalar =`
// prefix of the load, and the trailing `return success();` were lost in
// extraction.
1519 LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
1522 auto unpackOp = cast<UnPackOp>(op);
1523 assert(unpackOp.hasPureBufferSemantics() &&
1524 "expected operation to have buffer semantics");
1525 assert(ivs.size() == unpackOp.getDestRank() &&
1526 "number of ivs must match the rank of the output tensor");
1527 OpBuilder::InsertionGuard g(builder);
// Maps each tiled destination dim to its (mixed) inner tile size.
1530 unpackOp.getDimAndTileMapping();
1532 SmallVector<Value> inputIvs;
1534 SmallVector<Value> inputIvsPointLoops;
1535 inputIvs.reserve(unpackOp.getDestRank());
1536 inputIvsPointLoops.reserve(dimAndTileMapping.size());
1537 for (
auto dim : llvm::seq<int64_t>(0, unpackOp.getDestRank())) {
1538 if (dimAndTileMapping.count(dim)) {
// Tiled dim: iv = quotient * tileSize + remainder.
1539 affine::DivModValue divMod =
1540 affine::getDivMod(builder, loc, ivs[dim],
1542 builder, loc, dimAndTileMapping[dim]));
1543 inputIvsPointLoops.push_back(divMod.remainder);
1544 inputIvs.push_back(divMod.quotient);
// Untiled dim: source index equals the destination iv.
1546 inputIvs.push_back(ivs[dim]);
// Outer indices + inner point-loop indices must add up to the source rank.
1552 assert(inputIvsPointLoops.size() + inputIvs.size() ==
1553 unpackOp.getSourceRank() &&
1554 "expect same number of induction variables equals to input rank");
// Reorder the inner point-loop indices according to inner_dims_pos.
1556 ArrayRef<int64_t> innerDims = unpackOp.getInnerDimsPos();
1557 SmallVector<int64_t> interchangeVector =
1558 computeInterchangeFromDimPos(innerDims, unpackOp.getDestRank());
1559 SmallVector<Value> interchangedInputIvsPointLoops = inputIvsPointLoops;
1560 interchangedInputIvsPointLoops = interchange<Value>(
1561 interchangedInputIvsPointLoops, interchangeVector, 0);
// Apply outer_dims_perm to the outer indices, if present.
1564 ArrayRef<int64_t> outerDims = unpackOp.getOuterDimsPerm();
1565 if (!outerDims.empty())
1566 inputIvs = interchange<Value>(inputIvs, outerDims, 0);
1568 llvm::append_range(inputIvs, interchangedInputIvsPointLoops);
// Copy one scalar: source[outer..., inner...] -> dest[ivs].
1570 memref::LoadOp::create(builder, loc, unpackOp.getSource(), inputIvs);
1571 memref::StoreOp::create(builder, loc, scalar, unpackOp.getDest(), ivs);
// Map a tile of ONE operand (packed source or dest) of linalg.unpack to the
// corresponding tile of the iteration domain (which equals the dest tensor).
// For the dest operand the mapping is the identity; for the source, the
// inner-tile dims are dropped, offsets are scaled by the tile sizes, and
// sizes are clamped with a min against the remaining extent of the output.
// NOTE(review): the early `return failure()` statements, the
// reifyResultShapes call populating `reifiedReturnShapes`, and the final
// `return success();` were lost in extraction.
1577 LogicalResult getIterationDomainTileFromOperandTiles(
1578 Operation *op, OpBuilder &
b, ArrayRef<unsigned> operandNumbers,
1579 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1580 ArrayRef<SmallVector<OpFoldResult>> allSizes,
1581 SmallVectorImpl<OpFoldResult> &resultOffsets,
1582 SmallVectorImpl<OpFoldResult> &resultSizes)
const {
// Only a single operand tile is supported.
1583 if (operandNumbers.size() != 1) {
1584 LLVM_DEBUG({ llvm::dbgs() <<
"unable to handle multiple operands"; });
1587 auto unPackOp = cast<UnPackOp>(op);
1588 unsigned operandNumber = operandNumbers[0];
1589 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1590 ArrayRef<OpFoldResult> sizes(allSizes[0]);
// Dest operand: iteration-domain tile is the operand tile itself.
1593 if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
1594 resultOffsets = llvm::to_vector(offsets);
1595 resultSizes = llvm::to_vector(sizes);
1598 Location loc = unPackOp.getLoc();
// Source operand: strip the trailing inner-tile dims.
1600 int64_t numTiles = unPackOp.getInnerDimsPos().size();
1601 auto destOffsets = offsets.drop_back(numTiles);
1602 auto destSizes = sizes.drop_back(numTiles);
1605 int64_t outputRank = unPackOp.getDestRank();
// NOTE(review): `reifiedReturnShapes` is populated by an elided
// reifyResultShapes call just above this line.
1609 SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
1610 SmallVector<OpFoldResult> origOffsets(destOffsets);
1611 SmallVector<OpFoldResult> origSizes(destSizes);
// Undo outer_dims_perm so dims line up with the output ordering
// (the permutation argument of this call was elided).
1612 applyPermToRange(origOffsets, origSizes,
1616 unPackOp.getDimAndTileMapping();
1618 for (
auto dim : llvm::seq<int64_t>(0, outputRank)) {
1619 using AV = affine::AffineValueExpr;
1620 affine::AffineBuilder ab(
b, loc);
1621 AffineExpr dim0, dim1, sym0;
// NOTE(review): the bindDims/bindSymbols calls for dim0/dim1/sym0 were
// elided here.
1624 if (dimAndTileMapping.count(dim)) {
// Tiled dim: offset' = offset * tileSize;
// size' = min(size * tileSize, outputSize - offset').
1628 auto avOffset = AV(dim0).bind(origOffsets[dim]);
1629 auto avSize = AV(dim0).bind(origSizes[dim]);
1630 auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
1631 auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
1632 resultOffsets.push_back(ab.mul(avOffset, avTileSize));
1633 auto avResultOffset = AV(dim1).bind(resultOffsets.back());
1634 resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
1635 ab.sub(avResultSize, avResultOffset)}));
// Untiled dim: pass the range through unchanged.
1637 resultOffsets.push_back(origOffsets[dim]);
1638 resultSizes.push_back(origSizes[dim]);
// Consumer-fusion entry point: given a tile of the packed SOURCE operand
// (operand 0 only), build the fused tiled unpack. Maps the operand tile to
// the iteration domain, extracts matching dest and source slices, and
// re-creates the unpack on them.
// NOTE(review): the early `return failure()` statements and the interior of
// the full-inner-tile check loop (over zip_equal) were lost in extraction.
1645 FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
1646 Operation *op, OpBuilder &
b, ArrayRef<unsigned> operandNumbers,
1647 ArrayRef<SmallVector<OpFoldResult>> allOffsets,
1648 ArrayRef<SmallVector<OpFoldResult>> allSizes)
const {
// Only fusion through the single source operand is supported.
1649 if (operandNumbers.size() != 1 || operandNumbers[0] != 0) {
1650 LLVM_DEBUG({ llvm::dbgs() <<
"unhandled operands for consumer fusion"; });
1653 auto unPackOp = cast<UnPackOp>(op);
1655 if (!unPackOp.hasPureTensorSemantics())
1658 ArrayRef<OpFoldResult> offsets(allOffsets[0]);
1659 ArrayRef<OpFoldResult> sizes(allSizes[0]);
// The trailing inner-tile slice sizes must cover whole tiles; the check
// body inside this loop was elided.
1663 int64_t numTiles = unPackOp.getInnerDimsPos().size();
1665 llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
1670 Location loc = unPackOp.getLoc();
// Translate the source-operand tile into an output (iteration-domain) tile.
1674 SmallVector<OpFoldResult> outputOffsets, outputSizes;
1675 if (
failed(getIterationDomainTileFromOperandTiles(
1676 op,
b, operandNumbers, allOffsets, allSizes, outputOffsets,
1680 auto oneAttr =
b.getI64IntegerAttr(1);
1681 int64_t outputRank = unPackOp.getDestRank();
1682 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
1684 SmallVector<Value> tiledOperands;
// Slice of the destination matching the computed output tile.
1686 auto extractDestSlice = tensor::ExtractSliceOp::create(
1687 b, loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
1688 tiledOperands.push_back(extractDestSlice);
// Extend strides to the source rank (unit stride on inner-tile dims).
1690 strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
1692 auto extractSourceSlice = tensor::ExtractSliceOp::create(
1693 b, loc, unPackOp.getSource(), offsets, sizes, strides);
// Source slice goes first so operand order matches the original op.
1694 tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
1695 for (
auto tile : unPackOp.getInnerTiles())
1696 tiledOperands.push_back(
tile);
// NOTE(review): part of this create call's argument list was elided.
1699 Operation *tiledUnPackOp =
1700 UnPackOp::create(
b, loc,
TypeRange{extractDestSlice.getType()},
1703 return TilingResult{{tiledUnPackOp},
1704 SmallVector<Value>(tiledUnPackOp->
getResults()),
1705 llvm::to_vector(ArrayRef<Operation *>{
1706 extractSourceSlice, extractDestSlice})};
// registerOne<OpType>: attach both the tiling-interface and the
// partial-reduction external models to one linalg op type.
// NOTE(review): the function signature line (`static void
// registerOne(MLIRContext *ctx) {`) was lost in extraction.
1712template <
typename OpType>
1714 OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
1715 OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
// Registration helpers (fragment — the bodies of registerAll and of the two
// public registration functions are interleaved and heavily elided here).
// registerAll<OpTypes...> fans out to registerOne for each type; the public
// entry points additionally attach the Pack/UnPack tiling models and pull in
// the generated list of structured ops.
1720template <
typename... OpTypes>
1731 linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
1732 linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
// Generated op list used to instantiate registerOne for every structured op.
1734#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1742 linalg::PackOp::attachInterface<PackOpTiling>(*ctx);
1743 linalg::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
static bool isTiled(AffineExpr expr, ArrayRef< OpFoldResult > tileSizes)
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, ValueRange ivs, ValueRange argValues)
Method to inline the payload of a linalgOp given the iteration space point and values for the arguments.
static SmallVector< Value > getIndicesForAccess(OpBuilder &b, Location loc, AffineMap indexingMap, ValueRange ivs)
Return the SSA values that represent the data point accessed using a given indexingMap for a given po...
static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx)
Base type for affine expression.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setOperand(unsigned idx, Value value)
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void cloneInto(Region *dest, IRMapping &mapper)
Clone the internal blocks from this region into dest.
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, const StopConditionFn &stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
void registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry ®istry)
Similar to the above registration, but it is only for tensor.pack and tensor.unpack ops.
static void registerOne(MLIRContext *ctx)
static void registerAll(MLIRContext *ctx)
Variadic helper function.
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
void registerTilingInterfaceExternalModels(DialectRegistry ®istry)
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the give...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
llvm::SetVector< T, Vector, Set, N > SetVector
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Helper struct to build simple arithmetic quantities with minimal type inference support.
Container for result values of tiling.
Helper struct to build simple AffineValueExprs with minimal type inference support.
A struct containing offsets-sizes-strides arguments of the tiled shape.
SmallVector< OpFoldResult > sizes
SmallVector< OpFoldResult > offsets
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.