27 struct PadOpTiling :
public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
30 auto padOp = cast<PadOp>(op);
32 padOp.getResultType().getRank(), utils::IteratorType::parallel);
44 for (
const auto &ub :
enumerate(reifiedShapes[0]))
45 loopRanges[ub.index()].size = ub.value();
49 FailureOr<TilingResult>
53 FailureOr<TilingResult> result =
57 return result.value();
66 resultOffsets.assign(offsets.begin(), offsets.end());
67 resultSizes.assign(sizes.begin(), sizes.end());
71 LogicalResult getIterationDomainTileFromResultTile(
76 iterDomainOffsets.assign(offsets.begin(), offsets.end());
77 iterDomainSizes.assign(sizes.begin(), sizes.end());
81 FailureOr<TilingResult>
85 return getTiledImplementation(op, b, offsets, sizes);
89 template <
typename OpTy>
92 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
93 "applies to only pack or unpack operations");
95 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
102 for (
auto dim : llvm::seq<int64_t>(0, rank)) {
103 loopBounds[dim].offset = zero;
104 loopBounds[dim].stride = one;
105 loopBounds[dim].size = resultShape[0][dim];
113 if (permutation.empty())
115 applyPermutationToVector<OpFoldResult>(offsets, permutation);
116 applyPermutationToVector<OpFoldResult>(sizes, permutation);
120 :
public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
126 auto packOp = cast<PackOp>(op);
128 packOp.getSourceRank(), utils::IteratorType::parallel);
129 return iteratorTypes;
133 return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
136 FailureOr<TilingResult>
140 auto packOp = cast<PackOp>(op);
145 int64_t inputRank = packOp.getSourceRank();
148 applyPermToRange(origOffsets, origSizes,
152 packOp.getDimAndTileMapping();
156 for (
auto dim : llvm::seq<int64_t>(0, inputRank)) {
162 if (dimAndTileMapping.count(dim)) {
166 auto avOffset = AV(dim0).bind(origOffsets[dim]);
167 auto avSize = AV(dim0).bind(origSizes[dim]);
168 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
169 inputIndices.push_back(ab.mul(avOffset, avTileSize));
170 inputSizes.push_back(ab.mul(avSize, avTileSize));
172 inputIndices.push_back(origOffsets[dim]);
173 inputSizes.push_back(origSizes[dim]);
177 if (packOp.getPaddingValue()) {
179 auto avDimSize = AV(dim0).bind(dimSize);
180 auto avInputIdx = AV(dim1).bind(inputIndices.back());
182 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
190 auto sourceSlice = b.
create<ExtractSliceOp>(
191 loc, packOp.getSource(), inputIndices, inputSizes, strides);
192 tiledOperands.push_back(sourceSlice);
195 if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
199 strides.append(packOp.getDestRank() - inputRank, oneAttr);
200 auto outSlice = b.
create<ExtractSliceOp>(
201 loc, packOp.getDest(), outputOffsets, outputSizes, strides);
202 tiledOperands.push_back(outSlice);
204 if (
auto val = packOp.getPaddingValue())
205 tiledOperands.push_back(val);
206 for (
auto tile : packOp.getInnerTiles())
207 tiledOperands.push_back(
tile);
228 auto packOp = cast<PackOp>(op);
229 int64_t inputRank = packOp.getSourceRank();
230 int64_t outputRank = packOp.getDestRank();
232 resultOffsets.assign(offsets.begin(), offsets.end());
233 resultOffsets.append(outputRank - inputRank, zeroAttr);
237 resultSizes.assign(sizes.begin(), sizes.end());
238 for (
auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
239 resultSizes.push_back(outputShape[0][dataTileDim]);
244 FailureOr<TilingResult>
248 auto packOp = cast<PackOp>(op);
249 int64_t numTiles = packOp.getInnerDimsPos().size();
254 for (
auto offset : offsets.take_back(numTiles))
259 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
263 FailureOr<TilingResult> tilingResult = getTiledImplementation(
264 op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
265 if (failed(tilingResult))
267 return tilingResult.value();
273 LogicalResult getIterationDomainTileFromOperandTile(
278 if (operandNumber != 0)
281 auto packOp = cast<PackOp>(op);
284 if (packOp.getPaddingValue())
291 packOp.getDimAndTileMapping();
292 for (
auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
293 if (dimAndTileMapping.count(dim)) {
294 FailureOr<int64_t> cstSize =
298 std::optional<int64_t> cstInnerSize =
314 if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
323 auto avOffset = AV(dim0).bind(offsets[dim]);
324 auto avSize = AV(dim0).bind(sizes[dim]);
325 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
326 outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
327 outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
329 outerDimOffsets.push_back(offsets[dim]);
330 outerDimSizes.push_back(sizes[dim]);
333 applyPermToRange(outerDimOffsets, outerDimSizes, packOp.getOuterDimsPerm());
334 resultOffsets = outerDimOffsets;
335 resultSizes = outerDimSizes;
340 FailureOr<TilingResult> getTiledImplementationFromOperandTile(
343 if (operandNumber != 0)
346 auto packOp = cast<PackOp>(op);
349 int64_t inputRank = packOp.getSourceRank();
354 auto sourceSlice = b.
create<ExtractSliceOp>(loc, packOp.getSource(),
355 offsets, sizes, strides);
356 tiledOperands.push_back(sourceSlice);
359 if (failed(getIterationDomainTileFromOperandTile(
360 op, b, 0, offsets, sizes, outerDimOffsets,
365 if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
366 outputOffsets, outputSizes)))
369 strides.append(packOp.getDestRank() - inputRank, oneAttr);
370 auto outSlice = b.
create<ExtractSliceOp>(
371 loc, packOp.getDest(), outputOffsets, outputSizes, strides);
372 tiledOperands.push_back(outSlice);
374 assert(!packOp.getPaddingValue() &&
"Expect no padding semantic");
375 for (
auto tile : packOp.getInnerTiles())
376 tiledOperands.push_back(
tile);
388 struct UnpackTileDimInfo {
389 bool isAlignedToInnerTileSize;
399 static UnpackTileDimInfo getUnpackTileDimInfo(
OpBuilder &b, UnPackOp unpackOp,
403 UnpackTileDimInfo info;
407 unpackOp.getDimAndTileMapping();
409 if (!dimAndTileMapping.count(tileDim)) {
410 info.isAlignedToInnerTileSize =
true;
411 info.sourceOffset = tileOffset;
412 info.sourceSize = tileSize;
413 info.resultOffset = zeroAttr;
414 info.destExpandedSize = tileSize;
425 OpFoldResult innerTileSize = dimAndTileMapping[tileDim];
427 info.isAlignedToInnerTileSize =
false;
432 if (!failed(cstSize) && cstInnerSize) {
433 if (*cstSize % *cstInnerSize == 0)
434 info.isAlignedToInnerTileSize =
true;
438 if (*cstInnerSize == *cstSize) {
439 auto lhs = AV(dim0).bind(tileOffset);
440 auto rhs = AV(dim1).bind(innerTileSize);
441 info.sourceOffset = ab.floor(lhs, rhs);
442 info.sourceSize = oneAttr;
443 info.resultOffset = zeroAttr;
444 info.destExpandedSize = tileSize;
449 if (info.isAlignedToInnerTileSize) {
451 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
452 info.resultOffset = zeroAttr;
453 info.destExpandedSize = tileSize;
462 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
470 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
475 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
479 AV(dim1).bind(firstCoord.
quotient));
481 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
482 info.sourceOffset = firstCoord.
quotient;
483 info.resultOffset = firstCoord.
remainder;
492 struct UnPackOpTiling
493 :
public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {
496 auto unpackOp = cast<UnPackOp>(op);
498 unpackOp.getDestRank(), utils::IteratorType::parallel);
499 return iteratorTypes;
503 return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
520 FailureOr<TilingResult>
524 auto unpackOp = cast<UnPackOp>(op);
525 int64_t srcRank = unpackOp.getSourceRank();
526 int64_t destRank = unpackOp.getDestRank();
527 int64_t numInnerTiles = srcRank - destRank;
533 bool isPerfectTilingCase =
true;
538 for (
auto dim : llvm::seq<int64_t>(0, destRank)) {
539 UnpackTileDimInfo info =
540 getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
541 if (!info.isAlignedToInnerTileSize)
542 isPerfectTilingCase =
false;
543 sliceSrcIndices.push_back(info.sourceOffset);
544 sliceSrcSizes.push_back(info.sourceSize);
545 destExpandedSizes.push_back(info.destExpandedSize);
546 resultOffsetsFromDest.push_back(info.resultOffset);
551 applyPermToRange(sliceSrcIndices, sliceSrcSizes,
552 unpackOp.getOuterDimsPerm());
554 sliceSrcIndices.append(numInnerTiles, zeroAttr);
555 sliceSrcSizes.append(unpackOp.getMixedTiles());
556 sliceSrcStrides.append(numInnerTiles, oneAttr);
558 b.
create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
559 sliceSrcSizes, sliceSrcStrides);
564 if (isPerfectTilingCase) {
565 auto destSliceOp = b.
create<ExtractSliceOp>(loc, unpackOp.getDest(),
566 offsets, sizes, destStrides);
567 sliceDest = destSliceOp;
568 generatedSlices.push_back(destSliceOp);
570 sliceDest = b.
create<EmptyOp>(loc, destExpandedSizes,
571 unpackOp.getDestType().getElementType());
575 for (
auto tile : unpackOp.getInnerTiles())
576 tiledOperands.push_back(
tile);
581 if (isPerfectTilingCase)
588 resultOffsetsFromDest, sizes, destStrides);
589 generatedSlices.push_back(extractSlice);
591 {tiledUnpackOp}, {extractSlice.
getResult()}, generatedSlices};
600 resultOffsets = llvm::to_vector(offsets);
601 resultSizes = llvm::to_vector(sizes);
605 FailureOr<TilingResult>
609 FailureOr<TilingResult> tilingResult =
610 getTiledImplementation(op, b, offsets, sizes);
611 if (failed(tilingResult))
613 return tilingResult.value();
618 LogicalResult getIterationDomainTileFromOperandTile(
623 auto unPackOp = cast<UnPackOp>(op);
626 int64_t numTiles = unPackOp.getInnerDimsPos().size();
627 auto destOffsets = offsets.drop_back(numTiles);
628 auto destSizes = sizes.drop_back(numTiles);
631 int64_t outputRank = unPackOp.getDestRank();
634 applyPermToRange(origOffsets, origSizes,
638 unPackOp.getDimAndTileMapping();
640 for (
auto dim : llvm::seq<int64_t>(0, outputRank)) {
646 if (dimAndTileMapping.count(dim)) {
650 auto avOffset = AV(dim0).bind(origOffsets[dim]);
651 auto avSize = AV(dim0).bind(origSizes[dim]);
652 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
653 resultOffsets.push_back(ab.mul(avOffset, avTileSize));
654 resultSizes.push_back(ab.mul(avSize, avTileSize));
656 resultOffsets.push_back(origOffsets[dim]);
657 resultSizes.push_back(origSizes[dim]);
664 FailureOr<TilingResult> getTiledImplementationFromOperandTile(
667 auto unPackOp = cast<UnPackOp>(op);
670 int64_t numTiles = unPackOp.getInnerDimsPos().size();
672 llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
682 if (failed(getIterationDomainTileFromOperandTile(
683 op, b, 0, offsets, sizes, outputOffsets,
688 int64_t outputRank = unPackOp.getDestRank();
693 auto extractDestSlice = b.
create<ExtractSliceOp>(
694 loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
695 tiledOperands.push_back(extractDestSlice);
698 strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
700 auto extractSourceSlice = b.
create<ExtractSliceOp>(
701 loc, unPackOp.getSource(), offsets, sizes, strides);
702 tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
703 for (
auto tile : unPackOp.getInnerTiles())
704 tiledOperands.push_back(
tile);
714 extractSourceSlice, extractDestSlice})};
724 bool generateZeroSliceGuard) {
726 Value padValue = padOp.getConstantPaddingValue();
761 bool hasZeroLen =
false;
764 Value dynHasZeroLenCond;
766 int64_t rank = padOp.getSourceType().getRank();
767 for (
unsigned dim = 0; dim < rank; ++dim) {
768 auto low = padOp.getMixedLowPad()[dim];
770 auto high = padOp.getMixedHighPad()[dim];
772 auto offset = offsets[dim];
773 auto length = sizes[dim];
782 newLows.push_back(newLow);
797 ?
min(
max(sub(offset, low), zero), srcSize)
798 :
min(offset, srcSize);
799 newOffsets.push_back(newOffset);
821 hasLowPad ?
min(
max(add(sub(offset, low), length), zero), srcSize)
822 :
min(add(offset, length), srcSize);
824 newLengths.push_back(newLength);
830 }
else if (!hasZeroLen) {
832 loc, arith::CmpIPredicate::eq,
837 ? b.
create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
846 hasHighPad ? sub(sub(length, newLength), newLow) : zero;
847 newHighs.push_back(newHigh);
857 RankedTensorType resultType =
862 if (resultType == val.getType())
864 return b.
create<tensor::CastOp>(loc, resultType, val);
870 auto createGenerateOp = [&]() {
872 auto generateOp = b.
create<tensor::GenerateOp>(
873 loc, resultType, dynDims,
875 builder.create<tensor::YieldOp>(gLoc, padValue);
882 auto createPadOfExtractSlice = [&]() {
884 auto newSliceOp = b.
create<tensor::ExtractSliceOp>(
885 loc, padOp.getSource(), newOffsets, newLengths, newStrides);
886 auto newPadOp = b.
create<PadOp>(
887 loc,
Type(), newSliceOp, newLows, newHighs,
893 padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
896 return std::make_tuple(newPadOp, newSliceOp);
902 Operation *generateOp = createGenerateOp();
910 if (generateZeroSliceGuard && dynHasZeroLenCond) {
914 auto result = b.
create<scf::IfOp>(
915 loc, dynHasZeroLenCond,
918 thenOp = createGenerateOp();
919 b.create<scf::YieldOp>(loc, castResult(thenOp->
getResult(0)));
923 std::tie(elseOp, sliceOp) = createPadOfExtractSlice();
924 b.create<scf::YieldOp>(loc, castResult(elseOp->
getResult(0)));
930 auto [newPadOp, sliceOp] = createPadOfExtractSlice();
932 {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}};
938 tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
939 tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
940 tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
947 tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
948 tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
result_range getResults()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs)
Create IR to calculate (div lhs, rhs) and (mod lhs, rhs).
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
void registerTilingInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers external models for Tiling interface for tensor ops.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
void registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry &registry)
Similar to the above registration, but only for tensor.pack and tensor.unpack ops.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is a constant integer equal to value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. sizeof...(exprs) - 1].
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions [0 .. sizeof...(exprs) - 1].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Container for result values of tiling.
Helper struct to build simple AffineValueExprs with minimal type inference support.
Holds the result of (div a, b) and (mod a, b).