27 struct PadOpTiling :
public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
30 auto padOp = cast<PadOp>(op);
32 padOp.getResultType().getRank(), utils::IteratorType::parallel);
44 for (
const auto &ub :
enumerate(reifiedShapes[0]))
45 loopRanges[ub.index()].size = ub.value();
49 FailureOr<TilingResult>
53 FailureOr<TilingResult> result =
57 return result.value();
66 resultOffsets.assign(offsets.begin(), offsets.end());
67 resultSizes.assign(sizes.begin(), sizes.end());
72 template <
typename OpTy>
75 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
76 "applies to only pack or unpack operations");
78 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
85 for (
auto dim : llvm::seq<int64_t>(0, rank)) {
86 loopBounds[dim].offset = zero;
87 loopBounds[dim].stride = one;
88 loopBounds[dim].size = resultShape[0][dim];
96 if (permutation.empty())
98 applyPermutationToVector<OpFoldResult>(offsets, permutation);
99 applyPermutationToVector<OpFoldResult>(sizes, permutation);
103 :
public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
109 auto packOp = cast<PackOp>(op);
111 packOp.getSourceRank(), utils::IteratorType::parallel);
112 return iteratorTypes;
116 return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
119 FailureOr<TilingResult>
123 auto packOp = cast<PackOp>(op);
128 int64_t inputRank = packOp.getSourceRank();
131 applyPermToRange(origOffsets, origSizes,
135 packOp.getDimAndTileMapping();
139 for (
auto dim : llvm::seq<int64_t>(0, inputRank)) {
145 if (dimAndTileMapping.count(dim)) {
149 auto avOffset = AV(dim0).bind(origOffsets[dim]);
150 auto avSize = AV(dim0).bind(origSizes[dim]);
151 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
152 inputIndices.push_back(ab.mul(avOffset, avTileSize));
153 inputSizes.push_back(ab.mul(avSize, avTileSize));
155 inputIndices.push_back(origOffsets[dim]);
156 inputSizes.push_back(origSizes[dim]);
160 if (packOp.getPaddingValue()) {
162 auto avDimSize = AV(dim0).bind(dimSize);
163 auto avInputIdx = AV(dim1).bind(inputIndices.back());
165 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
173 tiledOperands.push_back(b.
create<ExtractSliceOp>(
174 loc, packOp.getSource(), inputIndices, inputSizes, strides));
177 if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
181 strides.append(packOp.getDestRank() - inputRank, oneAttr);
182 auto extractSlice = b.
create<ExtractSliceOp>(
183 loc, packOp.getDest(), outputOffsets, outputSizes, strides);
184 tiledOperands.push_back(extractSlice);
186 if (
auto val = packOp.getPaddingValue())
187 tiledOperands.push_back(val);
188 for (
auto tile : packOp.getInnerTiles())
189 tiledOperands.push_back(
tile);
208 auto packOp = cast<PackOp>(op);
209 int64_t inputRank = packOp.getSourceRank();
210 int64_t outputRank = packOp.getDestRank();
212 resultOffsets.assign(offsets.begin(), offsets.end());
213 resultOffsets.append(outputRank - inputRank, zeroAttr);
217 resultSizes.assign(sizes.begin(), sizes.end());
218 for (
auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
219 resultSizes.push_back(outputShape[0][dataTileDim]);
224 FailureOr<TilingResult>
228 auto packOp = cast<PackOp>(op);
229 int64_t numTiles = packOp.getInnerDimsPos().size();
234 for (
auto offset : offsets.take_back(numTiles))
239 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
243 FailureOr<TilingResult> tilingResult = getTiledImplementation(
244 op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
245 if (failed(tilingResult))
247 return tilingResult.value();
251 struct UnpackTileDimInfo {
252 bool isAlignedToInnerTileSize;
262 static UnpackTileDimInfo getUnpackTileDimInfo(
OpBuilder &b, UnPackOp unpackOp,
266 UnpackTileDimInfo info;
270 unpackOp.getDimAndTileMapping();
272 if (!dimAndTileMapping.count(tileDim)) {
273 info.isAlignedToInnerTileSize =
true;
274 info.sourceOffset = tileOffset;
275 info.sourceSize = tileSize;
276 info.resultOffset = zeroAttr;
277 info.destExpandedSize = tileSize;
288 OpFoldResult innerTileSize = dimAndTileMapping[tileDim];
290 info.isAlignedToInnerTileSize =
false;
295 if (!failed(cstSize) && cstInnerSize) {
296 if (*cstSize % *cstInnerSize == 0)
297 info.isAlignedToInnerTileSize =
true;
301 if (*cstInnerSize == *cstSize) {
302 auto lhs = AV(dim0).bind(tileOffset);
303 auto rhs = AV(dim1).bind(innerTileSize);
304 info.sourceOffset = ab.floor(lhs, rhs);
305 info.sourceSize = oneAttr;
306 info.resultOffset = zeroAttr;
307 info.destExpandedSize = tileSize;
312 if (info.isAlignedToInnerTileSize) {
314 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
315 info.resultOffset = zeroAttr;
316 info.destExpandedSize = tileSize;
325 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
333 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
338 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
342 AV(dim1).bind(firstCoord.
quotient));
344 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
345 info.sourceOffset = firstCoord.
quotient;
346 info.resultOffset = firstCoord.
remainder;
355 struct UnPackOpTiling
356 :
public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {
359 auto unpackOp = cast<UnPackOp>(op);
361 unpackOp.getDestRank(), utils::IteratorType::parallel);
362 return iteratorTypes;
366 return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
383 FailureOr<TilingResult>
387 auto unpackOp = cast<UnPackOp>(op);
388 int64_t srcRank = unpackOp.getSourceRank();
389 int64_t destRank = unpackOp.getDestRank();
390 int64_t numInnerTiles = srcRank - destRank;
396 bool isPerfectTilingCase =
true;
401 for (
auto dim : llvm::seq<int64_t>(0, destRank)) {
402 UnpackTileDimInfo info =
403 getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
404 if (!info.isAlignedToInnerTileSize)
405 isPerfectTilingCase =
false;
406 sliceSrcIndices.push_back(info.sourceOffset);
407 sliceSrcSizes.push_back(info.sourceSize);
408 destExpandedSizes.push_back(info.destExpandedSize);
409 resultOffsetsFromDest.push_back(info.resultOffset);
414 applyPermToRange(sliceSrcIndices, sliceSrcSizes,
415 unpackOp.getOuterDimsPerm());
417 sliceSrcIndices.append(numInnerTiles, zeroAttr);
418 sliceSrcSizes.append(unpackOp.getMixedTiles());
419 sliceSrcStrides.append(numInnerTiles, oneAttr);
421 b.
create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
422 sliceSrcSizes, sliceSrcStrides);
426 if (isPerfectTilingCase) {
427 sliceDest = b.
create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
430 sliceDest = b.
create<EmptyOp>(loc, destExpandedSizes,
431 unpackOp.getDestType().getElementType());
435 for (
auto tile : unpackOp.getInnerTiles())
436 tiledOperands.push_back(
tile);
441 if (isPerfectTilingCase)
447 resultOffsetsFromDest, sizes, destStrides);
457 resultOffsets = llvm::to_vector(offsets);
458 resultSizes = llvm::to_vector(sizes);
462 FailureOr<TilingResult>
466 FailureOr<TilingResult> tilingResult =
467 getTiledImplementation(op, b, offsets, sizes);
468 if (failed(tilingResult))
470 return tilingResult.value();
475 LogicalResult getIterationDomainTileFromOperandTile(
480 auto unPackOp = cast<UnPackOp>(op);
483 int64_t numTiles = unPackOp.getInnerDimsPos().size();
484 auto destOffsets = offsets.drop_back(numTiles);
485 auto destSizes = sizes.drop_back(numTiles);
488 int64_t outputRank = unPackOp.getDestRank();
492 applyPermToRange(origOffsets, origSizes,
496 unPackOp.getDimAndTileMapping();
498 for (
auto dim : llvm::seq<int64_t>(0, outputRank)) {
504 if (dimAndTileMapping.count(dim)) {
508 auto avOffset = AV(dim0).bind(origOffsets[dim]);
509 auto avSize = AV(dim0).bind(origSizes[dim]);
510 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
511 resultOffsets.push_back(ab.mul(avOffset, avTileSize));
512 resultSizes.push_back(ab.mul(avSize, avTileSize));
514 resultOffsets.push_back(origOffsets[dim]);
515 resultSizes.push_back(origSizes[dim]);
522 FailureOr<TilingResult> getTiledImplementationFromOperandTile(
525 auto unPackOp = cast<UnPackOp>(op);
528 int64_t numTiles = unPackOp.getInnerDimsPos().size();
530 llvm::zip_equal(unPackOp.getMixedTiles(), sizes.take_back(numTiles))) {
540 if (failed(getIterationDomainTileFromOperandTile(
541 op, b, 0, offsets, sizes, outputOffsets,
546 int64_t outputRank = unPackOp.getDestRank();
551 auto extractDestSlice = b.
create<ExtractSliceOp>(
552 loc, unPackOp.getDest(), outputOffsets, outputSizes, strides);
553 tiledOperands.push_back(extractDestSlice);
556 strides.append(unPackOp.getSourceRank() - outputRank, oneAttr);
558 auto extractSourceSlice = b.
create<ExtractSliceOp>(
559 loc, unPackOp.getSource(), offsets, sizes, strides);
560 tiledOperands.insert(tiledOperands.begin(), extractSourceSlice);
561 for (
auto tile : unPackOp.getInnerTiles())
562 tiledOperands.push_back(
tile);
580 bool generateZeroSliceGuard) {
582 Value padValue = padOp.getConstantPaddingValue();
617 bool hasZeroLen =
false;
620 Value dynHasZeroLenCond;
622 int64_t rank = padOp.getSourceType().getRank();
623 for (
unsigned dim = 0; dim < rank; ++dim) {
624 auto low = padOp.getMixedLowPad()[dim];
626 auto high = padOp.getMixedHighPad()[dim];
628 auto offset = offsets[dim];
629 auto length = sizes[dim];
638 newLows.push_back(newLow);
653 ?
min(
max(sub(offset, low), zero), srcSize)
654 :
min(offset, srcSize);
655 newOffsets.push_back(newOffset);
677 hasLowPad ?
min(
max(add(sub(offset, low), length), zero), srcSize)
678 :
min(add(offset, length), srcSize);
680 newLengths.push_back(newLength);
686 }
else if (!hasZeroLen) {
688 loc, arith::CmpIPredicate::eq,
693 ? b.
create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
702 hasHighPad ? sub(sub(length, newLength), newLow) : zero;
703 newHighs.push_back(newHigh);
713 RankedTensorType resultType =
718 if (resultType == val.getType())
720 return b.
create<tensor::CastOp>(loc, resultType, val);
726 auto createGenerateOp = [&]() {
728 auto generateOp = b.
create<tensor::GenerateOp>(
729 loc, resultType, dynDims,
731 builder.create<tensor::YieldOp>(gLoc, padValue);
738 auto createPadOfExtractSlice = [&]() {
740 Value newSliceOp = b.
create<tensor::ExtractSliceOp>(
741 loc, padOp.getSource(), newOffsets, newLengths, newStrides);
742 auto newPadOp = b.
create<PadOp>(
743 loc,
Type(), newSliceOp, newLows, newHighs,
749 padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
758 Operation *generateOp = createGenerateOp();
764 if (generateZeroSliceGuard && dynHasZeroLenCond) {
767 auto result = b.
create<scf::IfOp>(
768 loc, dynHasZeroLenCond,
771 thenOp = createGenerateOp();
772 b.create<scf::YieldOp>(loc, castResult(thenOp->
getResult(0)));
776 elseOp = createPadOfExtractSlice();
777 b.create<scf::YieldOp>(loc, castResult(elseOp->
getResult(0)));
782 Operation *newPadOp = createPadOfExtractSlice();
789 tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
790 tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
791 tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
798 tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
799 tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
result_range getResults()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs)
Create IR to calculate (div lhs, rhs) and (mod lhs, rhs).
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
void registerTilingInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers external models for Tiling interface for tensor ops.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
void registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry &registry)
Similar to the above registration, but it is only for tensor.pack and tensor.unpack ops.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)).
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)).
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and interchanging the intermediate loops.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Container for result values of tiling.
Helper struct to build simple AffineValueExprs with minimal type inference support.
Holds the result of (div a, b) and (mod a, b).