struct PadOpTiling
    : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
  // ...
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    // ...
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    // ...
    return result.value();
    // ...
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
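// The PadOpTiling fragments above are, respectively, from getLoopIteratorTypes
// (every result dimension of tensor.pad is parallel), getIterationDomain (loop
// ranges are taken from the reified result shape), and getResultTilePosition
// (the result tile position is simply the requested offsets/sizes, since
// padding does not reorder data).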
template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  // ...
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  // ...
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  // ...
  return loopBounds;
}

static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
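// getPackUnPackIterationDomain builds a zero-offset, unit-stride iteration
// domain whose sizes come from the reified result shape; applyPermToRange
// permutes the offset/size vectors in place when an outer_dims_perm is given.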
struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
  // ...
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
    // ...
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
    auto packOp = cast<PackOp>(op);
    // ...
    int64_t inputRank = packOp.getSourceRank();
    // ...
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));
    // ...
    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    // ...
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      // ...
      if (dimAndTileMapping.count(dim)) {
        // ...
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }
      // ...
      if (packOp.getPaddingValue()) {
        // ...
        auto avDimSize = AV(dim0).bind(dimSize);
        auto avInputIdx = AV(dim1).bind(inputIndices.back());
        inputSizes.back() =
            ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
      }
    }
    // ...
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));
    // ...
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      // ...

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);
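// PackOpTiling::getTiledImplementation maps a tile of the (unpacked) source
// domain to a slice of the source: for a packed dimension the source index and
// size are offset * tile_size and size * tile_size, and when a padding value
// is present the source size is clamped so incomplete tiles do not read past
// the source extent. The destination slice is obtained via
// getResultTilePosition, and the padding value and inner tiles are forwarded
// as operands of the tiled pack op.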
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    // ...
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);
    // ...
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(outputShape[0][dataTileDim]);
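// For tensor.pack, the result tile keeps the given outer offsets/sizes and is
// extended with zero offsets and the reified sizes of the trailing data-tile
// dimensions.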
    auto packOp = cast<PackOp>(op);
    int64_t numTiles = packOp.getInnerDimsPos().size();
    // ...
    for (auto offset : offsets.take_back(numTiles))
      // ...
    // ...
         llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
      // ...
    FailureOr<TilingResult> tilingResult = getTiledImplementation(
        op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
    // ...
    return tilingResult.value();
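// generateResultTileValue only handles the "perfect" case: offsets on the
// inner-tile dimensions must be zero and the requested sizes there must equal
// the static inner tiles; the tile is then produced by tiling just the outer
// dimensions.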
struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  // ...
};

static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  // ...
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }
  // ...
  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  // ...
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;
    // ...
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    // ...
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }
  // ...
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  // ...
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
  // ...
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
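// getUnpackTileDimInfo computes, per destination dimension, where a tile lives
// in the packed source: untiled dimensions pass offsets/sizes through; when
// the tile size is a multiple of the inner tile size the source offset/size
// are floor- and ceil-divided by it; otherwise the first and last covered
// outer coordinates are derived with div/mod, and the remainder becomes the
// offset of the result inside the expanded destination.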
struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {
  // ...
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
    // ...
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    // ...
    bool isPerfectTilingCase = true;
    // ...
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }
    // ...
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    // ...
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    // ...
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);
    // ...
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }
    // ...
    for (auto tile : unpackOp.getInnerTiles())
      tiledOperands.push_back(tile);
    // ...
    if (isPerfectTilingCase)
      // ...

    // ...
        resultOffsetsFromDest, sizes, destStrides);
    // ...
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    // ...
    FailureOr<TilingResult> tilingResult =
        getTiledImplementation(op, b, offsets, sizes);
    // ...
    return tilingResult.value();
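// UnPackOpTiling::getTiledImplementation extracts a slice of the packed source
// (plus whole inner tiles), unpacks it into either a slice of the destination
// (perfect tiling) or a temporary tensor.empty of the expanded size, and, in
// the imperfect case, extracts the requested tile from that expanded result.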
FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
                                                 tensor::PadOp padOp,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes,
                                                 bool generateZeroSliceGuard) {
  // ...
  Value padValue = padOp.getConstantPaddingValue();
  // ...
  bool hasZeroLen = false;
  // ...
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    auto low = padOp.getMixedLowPad()[dim];
    // ...
    auto high = padOp.getMixedHighPad()[dim];
    // ...
    auto offset = offsets[dim];
    auto length = sizes[dim];
    // ...
    newLows.push_back(newLow);
    // ...
    OpFoldResult newOffset = hasLowPad
                                 ? min(max(sub(offset, low), zero), srcSize)
                                 : min(offset, srcSize);
    newOffsets.push_back(newOffset);
    // ...
        hasLowPad ? min(max(add(sub(offset, low), length), zero), srcSize)
                  : min(add(offset, length), srcSize);
    // ...
    newLengths.push_back(newLength);
    // ...
    } else if (!hasZeroLen) {
      Value check = b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq,
          // ...
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
      // ...
    }
    // ...
    OpFoldResult newHigh =
        hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    newHighs.push_back(newHigh);
    // ...
  }
  // ...
  RankedTensorType resultType =
      // ...
  auto castResult = [&](Value val) {
    if (resultType == val.getType())
      return val;
    return b.create<tensor::CastOp>(loc, resultType, val);
  };
  // ...
  auto createGenerateOp = [&]() {
    // ...
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        // ...
          builder.create<tensor::YieldOp>(gLoc, padValue);
        // ...
  };
  // ...
  auto createPadOfExtractSlice = [&]() {
    // ...
    Value newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(
        loc, Type(), newSliceOp, newLows, newHighs,
        // ...
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
    // ...
  };
  // ...
  Operation *generateOp = createGenerateOp();
  // ...
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    // ...
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        // ...
          thenOp = createGenerateOp();
          b.create<scf::YieldOp>(loc, castResult(thenOp->getResult(0)));
        // ...
          elseOp = createPadOfExtractSlice();
          b.create<scf::YieldOp>(loc, castResult(elseOp->getResult(0)));
    // ...
  }
  // ...
  Operation *newPadOp = createPadOfExtractSlice();
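// bubbleUpPadSlice rewrites extract_slice(pad(x)) as pad(extract_slice(x)):
// per dimension it clamps the new offset and length to the source extent and
// recomputes the low/high padding; if the source ends up unread (statically,
// or dynamically under the scf.if guard) a tensor.generate of the padding
// value is emitted instead of a zero-sized slice.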
// Inside registerTilingInterfaceExternalModels(DialectRegistry &registry):
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);

// Inside registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry &registry):
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
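// A minimal usage sketch (not part of this file; the setup shown is an
// assumption): the models above take effect once they are registered on a
// DialectRegistry from which the MLIRContext is constructed, e.g.:
//
//   mlir::DialectRegistry registry;
//   registry.insert<mlir::tensor::TensorDialect>();
//   mlir::tensor::registerTilingInterfaceExternalModels(registry);
//   mlir::MLIRContext context(registry);
//
// After this, tensor.pad, tensor.pack, and tensor.unpack implement
// TilingInterface and can be tiled by the generic tiling drivers.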