/// External model implementing the TilingInterface for tensor.pad.
struct PadOpTiling
    : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
  // getLoopIteratorTypes: every dimension of the pad result is parallel.
  auto padOp = cast<PadOp>(op);
  SmallVector<utils::IteratorType> iteratorTypes(
      padOp.getResultType().getRank(), utils::IteratorType::parallel);
  // ...

  // getIterationDomain: loop-range sizes come from the reified result shape.
  for (const auto &ub : enumerate(reifiedShapes[0]))
    loopRanges[ub.index()].size = ub.value();
  // ...

  // getTiledImplementation: delegates to bubbleUpPadSlice (defined below).
  FailureOr<TilingResult>
  // ...
  FailureOr<TilingResult> result =
  // ...
  return result.value();

  // getResultTilePosition: a tile of the iteration space maps 1:1 onto the
  // pad result.
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
  // ...

  LogicalResult getIterationDomainTileFromResultTile(
      // ...
    iterDomainOffsets.assign(offsets.begin(), offsets.end());
    iterDomainSizes.assign(sizes.begin(), sizes.end());
  // ...

  // generateResultTileValue: reuses getTiledImplementation.
  FailureOr<TilingResult>
  // ...
  return getTiledImplementation(op, b, offsets, sizes);
};
/// Iteration domain for pack/unpack: one parallel loop per outer dimension,
/// with sizes taken from the reified result shape.
template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  // ...
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  // ...
}

/// Applies `permutation` to `offsets` and `sizes` in place, if non-empty.
static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
/// External model implementing the TilingInterface for tensor.pack.
struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
  // getLoopIteratorTypes: the op is tiled over its source, and every source
  // dimension is parallel.
  auto packOp = cast<PackOp>(op);
  SmallVector<utils::IteratorType> iteratorTypes(
      packOp.getSourceRank(), utils::IteratorType::parallel);
  return iteratorTypes;

  // getIterationDomain:
  return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);

  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    // ...
    int64_t inputRank = packOp.getSourceRank();
    // Undo the outer-dims permutation so offsets/sizes index source dims.
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    // ...
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      if (dimAndTileMapping.count(dim)) {
        // Packed dimension: scale the offset and size by the inner tile size.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // With padding semantics, clamp the read so it stays inside the source.
      if (packOp.getPaddingValue()) {
        // ...
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }
    // ...
    auto sourceSlice = b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides);
    tiledOperands.push_back(sourceSlice);
    // ...
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return failure();
    // ...
    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);
    // ...
  }
  // getResultTilePosition: the result tile covers the same outer dims as the
  // iteration-space tile plus, whole, all trailing inner-tile dims.
  auto packOp = cast<PackOp>(op);
  int64_t inputRank = packOp.getSourceRank();
  int64_t outputRank = packOp.getDestRank();
  // ...
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultOffsets.append(outputRank - inputRank, zeroAttr);
  // ...
  resultSizes.assign(sizes.begin(), sizes.end());
  for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
    resultSizes.push_back(outputShape[0][dataTileDim]);
  // ...
  // generateResultTileValue: only supported when the requested result tile
  // spans whole inner tiles; then tiling the outer dims is sufficient.
  FailureOr<TilingResult>
  generateResultTileValue(/*...*/) const {
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();
    // Inner-tile offsets must be zero and their sizes must match the tiles.
    for (auto offset : offsets.take_back(numTiles))
      if (!isConstantIntValue(offset, 0))
        return failure();
    for (auto [tileSize, size] :
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      if (!isEqualConstantIntOrValue(tileSize, size))
        return failure();
    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    if (failed(tilingResult))
      return failure();
    return tilingResult.value();
  }
  // Maps a tile of the source operand (operand #0) to an iteration-domain
  // tile; used when fusing the pack as a consumer.
  LogicalResult getIterationDomainTileFromOperandTile(
      /*...*/) const {
    if (operandNumber != 0)
      return failure();
    auto packOp = cast<PackOp>(op);
    // Bail out on padding semantics: the padded region has no counterpart
    // in the source operand.
    if (packOp.getPaddingValue())
      return failure();
    // ...
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
      if (dimAndTileMapping.count(dim)) {
        // Only a source tile whose size is a multiple of the static inner
        // tile size can be mapped exactly onto outer dims.
        FailureOr<int64_t> cstSize =
            ValueBoundsConstraintSet::computeConstantBound(
                presburger::BoundType::UB, sizes[dim],
                /*stopCondition=*/nullptr, /*closedUB=*/true);
        std::optional<int64_t> cstInnerSize =
            getConstantIntValue(dimAndTileMapping[dim]);
        // ...
        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
          return failure();
        }
        // ...
        auto avOffset = AV(dim0).bind(offsets[dim]);
        auto avSize = AV(dim0).bind(sizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
      } else {
        outerDimOffsets.push_back(offsets[dim]);
        outerDimSizes.push_back(sizes[dim]);
      }
    }
    applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
    resultOffsets = outerDimOffsets;
    resultSizes = outerDimSizes;
    return success();
  }
  // Produces the tiled pack op for a given tile of the source operand.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      /*...*/) const {
    if (operandNumber != 0)
      return failure();
    auto packOp = cast<PackOp>(op);
    // ...
    int64_t inputRank = packOp.getSourceRank();
    // ...
    auto sourceSlice = b.create<ExtractSliceOp>(loc, packOp.getSource(),
                                                offsets, sizes, strides);
    tiledOperands.push_back(sourceSlice);
    // ...
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, 0, offsets, sizes, outerDimOffsets,
            outerDimSizes)))
      return failure();
    // ...
    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
                                     outputOffsets, outputSizes)))
      return failure();
    // ...
    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto outSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(outSlice);

    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);
    // ...
  }
};
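// Standalone worked example (plain C++, hypothetical values) of the index
// arithmetic used by PackOpTiling above. It is a sketch of the two mappings
// only, not of the MLIR builder code: getTiledImplementation scales an
// iteration-space (outer-dims) tile up to a slice of the source, and
// getIterationDomainTileFromOperandTile maps a source tile back down to the
// outer dims with floor/ceil division.
#include <algorithm>
#include <cassert>
#include <cstdint>

struct TileSketch {
  int64_t offset, size;
};

// Outer-dims tile -> source slice, mirroring the ab.mul(...) lines above.
// With padding semantics the read is clamped to the source extent.
TileSketch outerTileToSourceSlice(TileSketch outer, int64_t innerTile,
                                  int64_t srcDimSize, bool hasPadding) {
  int64_t srcOffset = outer.offset * innerTile;
  int64_t srcSize = outer.size * innerTile;
  if (hasPadding)
    srcSize = std::min(srcSize, srcDimSize - srcOffset);
  return {srcOffset, srcSize};
}

// Source tile -> outer-dims tile, mirroring ab.floor / ab.ceil above. The
// model only accepts source tiles whose size is a multiple of the (static)
// inner tile size, so the mapping is exact.
TileSketch sourceTileToOuterDims(TileSketch src, int64_t innerTile) {
  assert(src.size % innerTile == 0 && "only perfect tiles are supported");
  return {src.offset / innerTile, (src.size + innerTile - 1) / innerTile};
}

// e.g. innerTile = 8, srcDimSize = 30, padding enabled:
//   outerTileToSourceSlice({2, 3}, 8, 30, true) -> {16, 14}; the remaining
//   3*8 - 14 = 10 elements of the packed tile come from the padding value.
//   sourceTileToOuterDims({16, 16}, 8)          -> {2, 2}.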
/// Per-dimension information needed to tile a tensor.unpack op.
struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// For dimension `tileDim` of `unpackOp` and the tile [tileOffset,
/// tileOffset + tileSize), computes which slice of the packed source to read
/// and where the requested data lands within it.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  // ...
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // Dimensions that are not packed map through unchanged.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }
  // ...
  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  // ...
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;
    // If the tile size equals the inner tile size, the slice covers exactly
    // one outer element of the source.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    // ...
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  // Unaligned case: locate the outer/inner coordinates of the first and last
  // destination elements covered by the tile.
  DivModValue firstCoord = getDivMod(
      b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  DivModValue lastCoord = getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  // ... (destExpandedSize follows from sourceSize and the inner tile size)
  return info;
}
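// Standalone worked example (plain C++, hypothetical values) of the
// unaligned-case arithmetic above: the first and last destination elements of
// the tile are mapped to (outer, inner) coordinates of the packed source via
// division and remainder by the inner tile size.
#include <cstdint>

struct UnalignedTileSketch {
  int64_t sourceOffset;     // first outer source element to read
  int64_t sourceSize;       // number of outer source elements to read
  int64_t resultOffset;     // where the tile starts inside that read
  int64_t destExpandedSize; // sourceSize * innerTile (whole inner tiles)
};

UnalignedTileSketch computeUnalignedTile(int64_t tileOffset, int64_t tileSize,
                                         int64_t innerTile) {
  int64_t firstOuter = tileOffset / innerTile;             // firstCoord.quotient
  int64_t firstInner = tileOffset % innerTile;             // firstCoord.remainder
  int64_t lastOuter = (tileOffset + tileSize - 1) / innerTile; // lastCoord
  int64_t outerLen = lastOuter - firstOuter + 1;            // lengthMinusOne + 1
  return {firstOuter, outerLen, firstInner, outerLen * innerTile};
}
// e.g. tileOffset = 5, tileSize = 7, innerTile = 4: destination elements
// [5, 12) live in outer elements 1..2 of the source, so two outer elements
// (8 values) are read and the tile starts at offset 1 within that read.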
/// External model implementing the TilingInterface for tensor.unpack.
struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {
  // getLoopIteratorTypes: the op is tiled over its destination, and every
  // destination dimension is parallel.
  auto unpackOp = cast<UnPackOp>(op);
  SmallVector<utils::IteratorType> iteratorTypes(
      unpackOp.getDestRank(), utils::IteratorType::parallel);
  return iteratorTypes;

  // getIterationDomain:
  return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);

  // getTiledImplementation: slice the packed source (and, when the tile is
  // aligned to the inner tiles, the destination), unpack into that slice, and
  // otherwise extract the requested tile from the expanded result.
  FailureOr<TilingResult>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    // ...
    bool isPerfectTilingCase = true;
    // ...
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The outer dims of the source follow the outer-dims permutation; the
    // trailing inner-tile dims are taken whole.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    // ...
    ExtractSliceOp sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);
    generatedSlices.push_back(sliceSource);

    // ...
    if (isPerfectTilingCase) {
      auto destSliceOp = b.create<ExtractSliceOp>(loc, unpackOp.getDest(),
                                                  offsets, sizes, destStrides);
      sliceDest = destSliceOp;
      generatedSlices.push_back(destSliceOp);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }
    // ...
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);
    // ...
    if (isPerfectTilingCase)
      return TilingResult{{tiledUnpackOp}, {tiledUnpackOp->getResult(0)},
                          generatedSlices};

    // Otherwise extract the requested tile out of the expanded unpacked dest.
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, tiledUnpackOp->getResult(0),
        resultOffsetsFromDest, sizes, destStrides);
    return TilingResult{
        {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices};
  }
  // getResultTilePosition: an unpack result tile maps 1:1 onto the
  // iteration-space tile.
  resultOffsets = llvm::to_vector(offsets);
  resultSizes = llvm::to_vector(sizes);
  // ...

  // generateResultTileValue: reuses getTiledImplementation.
  FailureOr<TilingResult>
  // ...
  FailureOr<TilingResult> tilingResult =
      getTiledImplementation(op, b, offsets, sizes);
  if (failed(tilingResult))
    return failure();
  return tilingResult.value();
  // Maps a tile of either operand of the unpack to a tile of the iteration
  // domain (which is the destination's index space).
  LogicalResult getIterationDomainTileFromOperandTile(
      /*...*/) const {
    auto unPackOp = cast<UnPackOp>(op);
    // A tile of the destination operand is already an iteration-domain tile.
    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
      resultOffsets = llvm::to_vector(offsets);
      resultSizes = llvm::to_vector(sizes);
      return success();
    }
    // Otherwise the tile is on the packed source: drop the inner-tile dims
    // and map the outer dims back into the destination index space.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    auto destOffsets = offsets.drop_back(numTiles);
    auto destSizes = sizes.drop_back(numTiles);
    // ...
    int64_t outputRank = unPackOp.getDestRank();
    // ...
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(unPackOp.getOuterDimsPerm()));
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        unPackOp.getDimAndTileMapping();
    // ...
    for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
      if (dimAndTileMapping.count(dim)) {
        // Scale by the inner tile size and clamp to the destination extent.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
        resultOffsets.push_back(ab.mul(avOffset, avTileSize));
        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
                                      ab.sub(avResultSize, avResultOffset)}));
      } else {
        resultOffsets.push_back(origOffsets[dim]);
        resultSizes.push_back(origSizes[dim]);
      }
    }
    return success();
  }
  // Produces the tiled unpack op for a given tile of the packed source.
  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
      /*...*/) const {
    auto unPackOp = cast<UnPackOp>(op);
    // ...
    // Each inner tile must be taken as a whole.
    int64_t numTiles = unPackOp.getInnerDimsPos().size();
    for (auto [tileSize, size] :
         llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
      if (!isEqualConstantIntOrValue(tileSize, size))
        return failure();
    }
    // ...
    if (failed(getIterationDomainTileFromOperandTile(
            op, b, 0, offsets, sizes, outputOffsets,
            outputSizes)))
      return failure();
    // ...
    int64_t outputRank = unPackOp.getDestRank();
    // ...
    auto extractDestSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractDestSlice);
    // The source slice keeps whole inner tiles, hence unit strides on the
    // trailing inner-tile dims.
    strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
    // ...
    auto extractSourceSlice = b.create<ExtractSliceOp>(
        loc, unPackOp.getSource(), offsets, sizes, strides);
    tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
    for (auto tile : unPackOp.getInnerTiles())
      tiledOperands.push_back(tile);
    // Both extract_slice ops are reported as generated slices.
    // ...
        extractSourceSlice, extractDestSlice})};
  }
};
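// Standalone worked example (plain C++, hypothetical values) of the
// arithmetic in getIterationDomainTileFromOperandTile above: an outer-dims
// tile of the packed source expands by the inner tile size and is clamped to
// the destination extent, trimming a possibly partial last tile.
#include <algorithm>
#include <cstdint>

struct DestTileSketch {
  int64_t offset, size;
};

DestTileSketch packedOuterTileToDestTile(int64_t outerOffset,
                                         int64_t outerSize, int64_t innerTile,
                                         int64_t destDimSize) {
  int64_t destOffset = outerOffset * innerTile;
  int64_t destSize =
      std::min(outerSize * innerTile, destDimSize - destOffset);
  return {destOffset, destSize};
}
// e.g. innerTile = 8, destDimSize = 30: the outer tile {3, 1} maps to
// destination elements [24, 30), i.e. offset 24 with size 6.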
/// Bubbles up a slice of the pad: the slice is taken from the source first
/// and the (remaining) padding is applied afterwards.
FailureOr<TilingResult>
tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         bool generateZeroSliceGuard) {
  // ...
  Value padValue = padOp.getConstantPaddingValue();
  // ...

  // Whether the requested slice is statically known to be empty, plus a
  // runtime condition covering the dynamic case.
  bool hasZeroLen = false;
  // ...
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    // ...
    auto high = padOp.getMixedHighPad()[dim];
    // ...
    auto offset = offsets[dim];
    auto length = sizes[dim];
    // ... (srcSize, hasLowPad/hasHighPad and the remaining low pad `newLow`
    //      are computed here)
    newLows.push_back(newLow);

    // Clamp the start of the read into the source's index space.
    OpFoldResult newOffset = hasLowPad
        ? min(max(sub(offset, low), zero), srcSize)
        : min(offset, srcSize);
    newOffsets.push_back(newOffset);

    // The clamped end of the read determines the new length:
    //   newLength = min(max(offset - low + length, 0), srcSize) - newOffset
    // ...
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    // ...
    newLengths.push_back(newLength);

    // Record whether the slice is empty, statically or dynamically.
    // ...
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          /*...*/);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // Whatever the source read does not cover becomes high padding.
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);
    // ...
  }

  // Casts a produced value to the expected result type when necessary.
  RankedTensorType resultType =
  // ...
  auto castResult = [&](Value val) -> Value {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // Creates a tensor.generate filled with the padding value; used when the
  // slice lies entirely inside the padded region.
  auto createGenerateOp = [&]() {
    // ...
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange /*indices*/) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return generateOp;
  };

  // Creates pad(extract_slice(source)) with the adjusted offsets, lengths
  // and paddings.
  auto createPadOfExtractSlice = [&]() {
    // ...
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        /*...*/);
    // Copy the padding-value region from the original pad.
    // ...
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
    // ...
    return std::make_tuple(newPadOp, newSliceOp);
  };

  // Statically empty slice: only the generate op is needed.
  // ...
  Operation *generateOp = createGenerateOp();
  // ...

  // Dynamically possibly-empty slice: guard with an scf.if that chooses
  // between tensor.generate and pad(extract_slice).
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    // ...
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
        });
    // ...
  }

  auto [newPadOp, sliceOp] = createPadOfExtractSlice();
  return TilingResult{
      {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
}
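// Standalone worked example (plain C++, hypothetical values) of the
// per-dimension clamping in bubbleUpPadSlice above: given a tile
// [offset, offset + length) of the padded result, compute the slice of the
// source that is actually read and the low/high padding that remains.
#include <algorithm>
#include <cstdint>

struct NewPadDimSketch {
  int64_t newOffset, newLength; // slice of the source
  int64_t newLow, newHigh;      // remaining padding
};

NewPadDimSketch bubbleUpPadDim(int64_t offset, int64_t length, int64_t low,
                               int64_t srcSize) {
  // Part of the tile that still falls into the low-padding region.
  int64_t newLow = std::max<int64_t>(low - offset, 0);
  // Clamp the start and the end of the read into the source's index space.
  int64_t newOffset = std::clamp<int64_t>(offset - low, 0, srcSize);
  int64_t newEnd = std::clamp<int64_t>(offset - low + length, 0, srcSize);
  int64_t newLength = newEnd - newOffset;
  // Whatever the source does not provide is high padding.
  int64_t newHigh = length - newLength - newLow;
  return {newOffset, newLength, newLow, newHigh};
}
// e.g. low = 3, srcSize = 10, tile [2, 8): one element of low padding remains
// (newLow = 1), the source is read at [0, 5), and no high padding is needed
// (newHigh = 0).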
// registerTilingInterfaceExternalModels: attaches the models above to the
// tensor ops through a dialect extension on the registry.
tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);

// registerTilingInterfaceExternalModelsForPackUnPackOps: same, but only for
// tensor.pack and tensor.unpack.
tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
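// Hypothetical usage sketch: how a client wires the external models above
// into an MLIRContext. The registration function name is the one defined
// above; the header that declares it must also be included (its exact path
// is assumed and omitted here).
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

void prepareContextForTensorTiling(mlir::MLIRContext &context) {
  mlir::DialectRegistry registry;
  // Attaches PadOpTiling, PackOpTiling and UnPackOpTiling to the tensor ops
  // once the tensor dialect is loaded.
  mlir::tensor::registerTilingInterfaceExternalModels(registry);
  context.appendDialectRegistry(registry);
}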