struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto padOp = cast<PadOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        padOp.getResultType().getRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
    ReifiedRankedShapedTypeDims reifiedShapes;
    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
    Location loc = op->getLoc();
    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    // Ranges start as {zero, one, one}; the sizes are overwritten below with
    // the reified result shape.
    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
    for (const auto &ub : enumerate(reifiedShapes[0]))
      loopRanges[ub.index()].size = ub.value();
    return loopRanges;
  }

  // getTiledImplementation (elided here) forwards to tensor::bubbleUpPadSlice
  // defined below; the result tile position is simply the requested
  // offsets/sizes.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
};
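// Illustrative sketch (not from the file itself): with this model attached,
// asking for a tile of a tensor.pad result yields, schematically,
//
//   %s = tensor.extract_slice %source[...] [...] [1, 1]
//       : tensor<?x?xf32> to tensor<?x?xf32>
//   %t = tensor.pad %s low[...] high[...] {
//     ^bb0(%i: index, %j: index):
//       tensor.yield %cst : f32
//   } : tensor<?x?xf32> to tensor<?x?xf32>
//
// i.e. the slice is taken from the original source first and a smaller pad is
// re-created around it; see tensor::bubbleUpPadSlice further down. The
// offsets, sizes and padding amounts above are schematic placeholders.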
template <typename OpTy>
static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
                                                       OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Location loc = op.getLoc();
  // The domain rank is the source rank for pack and the dest rank for unpack,
  // i.e. the rank of the unpacked tensor in both cases.
  int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
                                                     : op.getDestRank();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  ReifiedRankedShapedTypeDims resultShape;
  (void)op.reifyResultShapes(builder, resultShape);
  SmallVector<Range> loopBounds(rank);
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].stride = one;
    loopBounds[dim].size = resultShape[0][dim];
  }
  return loopBounds;
}
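// Worked example (assumed shapes, for illustration only): for
//
//   %p = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//        into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
//
// the rank used above is the source rank (2), and the loop bounds are taken
// from the first two (outer) dimensions of the reified result shape, giving a
// 16 x 8 domain with offset 0 and stride 1 in each dimension. For an unpack
// op the domain instead spans the full destination shape.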
static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
                             SmallVector<OpFoldResult> &sizes,
                             ArrayRef<int64_t> permutation) {
  if (permutation.empty())
    return;
  applyPermutationToVector<OpFoldResult>(offsets, permutation);
  applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
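// For illustration (hypothetical values): with permutation [1, 0], offsets
// {%o0, %o1} become {%o1, %o0} and sizes {%s0, %s1} become {%s1, %s0}, so the
// two vectors stay consistent with each other after the interchange.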
struct PackOpTiling
    : public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    // Only the untiled and outer tiled dimensions are iterated; the inner tile
    // dimensions are materialized when the tiled op is built.
    auto packOp = cast<PackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        packOp.getSourceRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
  }
  SmallVector<Operation *>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto packOp = cast<PackOp>(op);
    Location loc = packOp.getLoc();

    // The tiling is applied on the interchanged dimensions. Undo the
    // interchange to map offsets and sizes back to the source dimensions.
    int64_t inputRank = packOp.getSourceRank();
    SmallVector<OpFoldResult> origOffsets(offsets.begin(), offsets.end());
    SmallVector<OpFoldResult> origSizes(sizes.begin(), sizes.end());
    applyPermToRange(origOffsets, origSizes,
                     invertPermutationVector(packOp.getOuterDimsPerm()));

    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
        packOp.getDimAndTileMapping();
    SmallVector<OpFoldResult> srcDimValues =
        tensor::createDimValues(b, loc, packOp.getSource());
    SmallVector<OpFoldResult> inputIndices, inputSizes;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      using AV = AffineValueExpr;
      AffineBuilder ab(b, loc);
      AffineExpr dim0, dim1, sym;
      bindDims(b.getContext(), dim0, dim1);
      bindSymbols(b.getContext(), sym);
      if (dimAndTileMapping.count(dim)) {
        // A tiled data dimension: index = offset * tile_size and
        // size = size * tile_size.
        auto avOffset = AV(dim0).bind(origOffsets[dim]);
        auto avSize = AV(dim0).bind(origSizes[dim]);
        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
        inputIndices.push_back(ab.mul(avOffset, avTileSize));
        inputSizes.push_back(ab.mul(avSize, avTileSize));
      } else {
        inputIndices.push_back(origOffsets[dim]);
        inputSizes.push_back(origSizes[dim]);
      }

      // Limit the size of the input operand for incomplete tiles.
      OpFoldResult dimSize = srcDimValues[dim];
      auto avDimSize = AV(dim0).bind(dimSize);
      auto avInputIdx = AV(dim1).bind(inputIndices.back());
      inputSizes.back() =
          ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
    }

    auto oneAttr = b.getI64IntegerAttr(1);
    SmallVector<OpFoldResult> strides(inputRank, oneAttr);

    SmallVector<Value> tiledOperands;
    tiledOperands.push_back(b.create<ExtractSliceOp>(
        loc, packOp.getSource(), inputIndices, inputSizes, strides));

    SmallVector<OpFoldResult> outputOffsets, outputSizes;
    if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
                                     outputSizes)))
      return {};

    strides.append(packOp.getDestRank() - inputRank, oneAttr);
    auto extractSlice = b.create<ExtractSliceOp>(
        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
    tiledOperands.push_back(extractSlice);

    if (auto val = packOp.getPaddingValue())
      tiledOperands.push_back(val);
    for (auto tile : packOp.getInnerTiles())
      tiledOperands.push_back(tile);

    Operation *tiledPackOp = b.create<PackOp>(
        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
    return {tiledPackOp};
  }
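  // Illustrative result (schematic, not a verbatim test case): tiling the pack
  // from the example above with tile sizes (1, 1) on the outer domain produces
  // roughly
  //
  //   %in   = tensor.extract_slice %src[%i * 8, %j * 32] [8, 32] [1, 1]
  //   %out  = tensor.extract_slice %dest[%i, %j, 0, 0] [1, 1, 8, 32]
  //           [1, 1, 1, 1]
  //   %tile = tensor.pack %in inner_dims_pos = [0, 1] inner_tiles = [8, 32]
  //           into %out
  //
  // with the input sizes additionally clamped by affine.min against the source
  // bounds when the last tile is incomplete.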
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    // The iteration domain runs over the outer dimensions of the packed
    // layout, so the outer result offsets/sizes are the requested ones and the
    // inner (data tile) dimensions get zero offsets and their full sizes.
    auto packOp = cast<PackOp>(op);
    int64_t inputRank = packOp.getSourceRank();
    int64_t outputRank = packOp.getDestRank();
    auto zeroAttr = b.getIndexAttr(0);
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultOffsets.append(outputRank - inputRank, zeroAttr);

    ReifiedRankedShapedTypeDims outputShape;
    (void)packOp.reifyResultShapes(b, outputShape);
    resultSizes.assign(sizes.begin(), sizes.end());
    for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
      resultSizes.push_back(getAsOpFoldResult(outputShape[0][dataTileDim]));
    return success();
  }
};
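// Taken together with getTiledImplementation above: for the schematic 2-D pack
// used earlier, a domain tile at offsets (%i, %j) with sizes (1, 1) maps to
// result offsets (%i, %j, 0, 0) and result sizes (1, 1, 8, 32); the inner
// data-tile dimensions keep a zero offset and their full reified size.
// (Illustrative values only.)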
struct UnpackTileDimInfo {
  bool isAlignedToInnerTileSize;
  OpFoldResult sourceOffset;
  OpFoldResult sourceSize;
  OpFoldResult resultOffset;
  OpFoldResult destExpandedSize;
};

/// Returns the information needed for tiling an unpack op along `tileDim`
/// with the given `tileOffset` and `tileSize` of the (unpacked) destination.
static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
                                              int64_t tileDim,
                                              OpFoldResult tileOffset,
                                              OpFoldResult tileSize) {
  UnpackTileDimInfo info;
  Attribute zeroAttr = b.getIndexAttr(0);
  Attribute oneAttr = b.getIndexAttr(1);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  // The dimension is not one of the packed data dimensions: pass the tile
  // through unchanged.
  if (!dimAndTileMapping.count(tileDim)) {
    info.isAlignedToInnerTileSize = true;
    info.sourceOffset = tileOffset;
    info.sourceSize = tileSize;
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    return info;
  }

  Location loc = unpackOp.getLoc();
  using AV = AffineValueExpr;
  AffineBuilder ab(b, loc);
  AffineExpr dim0, dim1, sym0;
  bindDims(b.getContext(), dim0, dim1);
  bindSymbols(b.getContext(), sym0);
  OpFoldResult innerTileSize = dimAndTileMapping[tileDim];

  info.isAlignedToInnerTileSize = false;
  FailureOr<int64_t> cstSize = getConstantUpperBoundForIndex(
      getValueOrCreateConstantIndexOp(b, loc, tileSize));
  std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
  if (!failed(cstSize) && cstInnerSize) {
    if (*cstSize % *cstInnerSize == 0)
      info.isAlignedToInnerTileSize = true;
    // If the tile size equals the inner tile size, the outer dims are always 1.
    if (*cstInnerSize == *cstSize) {
      auto lhs = AV(dim0).bind(tileOffset);
      auto rhs = AV(dim1).bind(innerTileSize);
      info.sourceOffset = ab.floor(lhs, rhs);
      info.sourceSize = oneAttr;
      info.resultOffset = zeroAttr;
      info.destExpandedSize = tileSize;
      return info;
    }
  }

  if (info.isAlignedToInnerTileSize) {
    info.sourceOffset =
        ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
    info.resultOffset = zeroAttr;
    info.destExpandedSize = tileSize;
    // ceilDiv is needed because the last tile may be incomplete even in the
    // aligned case.
    info.sourceSize =
        ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
    return info;
  }

  // The tile touches partial inner tiles: compute the range of inner tiles
  // that covers [tileOffset, tileOffset + tileSize).
  DivModValue firstCoord =
      getDivMod(b, loc, getValueOrCreateConstantIndexOp(b, loc, tileOffset),
                getValueOrCreateConstantIndexOp(b, loc, innerTileSize));
  OpFoldResult tileExclusiveBound =
      ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
  DivModValue lastCoord = getDivMod(
      b, loc,
      getValueOrCreateConstantIndexOp(
          b, loc,
          ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
      getValueOrCreateConstantIndexOp(b, loc, innerTileSize));

  OpFoldResult lengthMinusOne = ab.sub(AV(dim0).bind(lastCoord.quotient),
                                       AV(dim1).bind(firstCoord.quotient));
  info.sourceSize =
      ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
  info.sourceOffset = firstCoord.quotient;
  info.resultOffset = firstCoord.remainder;
  info.destExpandedSize =
      ab.mul(AV(dim0).bind(info.sourceSize), AV(sym0).bind(innerTileSize));
  return info;
}
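// Worked example (hypothetical numbers), assuming innerTileSize = 32 on this
// dimension:
//  * Aligned case, tileOffset = 64, tileSize = 64: sourceOffset =
//    64 floordiv 32 = 2, sourceSize = ceildiv(64, 32) = 2, resultOffset = 0,
//    destExpandedSize = 64.
//  * Unaligned case, tileOffset = 48, tileSize = 64: firstCoord =
//    divmod(48, 32) = (1, 16) and lastCoord = divmod(48 + 64 - 1, 32) =
//    (3, 15), so sourceOffset = 1, sourceSize = 3 - 1 + 1 = 3,
//    resultOffset = 16 and destExpandedSize = 3 * 32 = 96; the caller reads
//    three whole inner tiles and later re-extracts the requested 64 elements
//    starting at offset 16.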
struct UnPackOpTiling
    : public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto unpackOp = cast<UnPackOp>(op);
    SmallVector<utils::IteratorType> iteratorTypes(
        unpackOp.getDestRank(), utils::IteratorType::parallel);
    return iteratorTypes;
  }

  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
  }
  /// Tiling an unpack op has two cases: in the "perfect tiling" case the tile
  /// sizes are multiples of the inner tile sizes and the tiled unpack writes
  /// directly into a slice of the destination; otherwise it writes into an
  /// over-sized temporary from which the requested tile is then extracted.
  SmallVector<Operation *>
  getTiledImplementation(Operation *op, OpBuilder &b,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes) const {
    auto unpackOp = cast<UnPackOp>(op);
    int64_t srcRank = unpackOp.getSourceRank();
    int64_t destRank = unpackOp.getDestRank();
    int64_t numInnerTiles = srcRank - destRank;
    Location loc = unpackOp.getLoc();

    bool isPerfectTilingCase = true;
    Attribute zeroAttr = b.getIndexAttr(0);
    Attribute oneAttr = b.getIndexAttr(1);
    SmallVector<OpFoldResult> sliceSrcStrides(destRank, oneAttr);
    SmallVector<OpFoldResult> sliceSrcIndices, sliceSrcSizes;
    SmallVector<OpFoldResult> destExpandedSizes, resultOffsetsFromDest;
    for (auto dim : llvm::seq<int64_t>(0, destRank)) {
      UnpackTileDimInfo info =
          getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
      if (!info.isAlignedToInnerTileSize)
        isPerfectTilingCase = false;
      sliceSrcIndices.push_back(info.sourceOffset);
      sliceSrcSizes.push_back(info.sourceSize);
      destExpandedSizes.push_back(info.destExpandedSize);
      resultOffsetsFromDest.push_back(info.resultOffset);
    }

    // The outer dims of the source follow outer_dims_perm; apply it to the
    // computed indices/sizes and append the inner tile dims in full.
    applyPermToRange(sliceSrcIndices, sliceSrcSizes,
                     unpackOp.getOuterDimsPerm());
    sliceSrcIndices.append(numInnerTiles, zeroAttr);
    sliceSrcSizes.append(unpackOp.getMixedTiles());
    sliceSrcStrides.append(numInnerTiles, oneAttr);
    Value sliceSource =
        b.create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
                                 sliceSrcSizes, sliceSrcStrides);

    SmallVector<OpFoldResult> destStrides(destRank, oneAttr);
    Value sliceDest;
    if (isPerfectTilingCase) {
      sliceDest = b.create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
                                           sizes, destStrides);
    } else {
      sliceDest = b.create<EmptyOp>(loc, destExpandedSizes,
                                    unpackOp.getDestType().getElementType());
    }

    Operation *tiledUnpackOp =
        b.create<UnPackOp>(loc, TypeRange{sliceDest.getType()},
                           ValueRange{sliceSource, sliceDest}, op->getAttrs());

    if (isPerfectTilingCase)
      return {tiledUnpackOp};

    Operation *extractSlice =
        b.create<ExtractSliceOp>(loc, tiledUnpackOp->getResult(0),
                                 resultOffsetsFromDest, sizes, destStrides);
    return {tiledUnpackOp, extractSlice};
  }
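  // Illustrative shapes (schematic, hypothetical values): unpacking
  // tensor<8x32xf32> (inner tile 32) into tensor<256xf32> and requesting the
  // dest tile [offset = 48, size = 64] hits the imperfect case above: the
  // source slice is tensor.extract_slice %src[1, 0] [3, 32] [1, 1], the
  // destination is a tensor.empty() of size 96, and a final
  // tensor.extract_slice re-extracts the 64 requested elements at offset 16
  // from the tiled unpack's result. With [offset = 64, size = 64] the tiling
  // is perfect and the unpack writes directly into
  // tensor.extract_slice %dest[64] [64] [1].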
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    resultOffsets = llvm::to_vector(offsets);
    resultSizes = llvm::to_vector(sizes);
    return success();
  }

  FailureOr<Value>
  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
    return getTiledImplementation(op, b, offsets, sizes)
        .back()
        ->getResult(resultNumber);
  }
};
/// Bubbles up a slice of this pad by taking the slice first and then
/// performing the padding.
Operation *tensor::bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes,
                                    bool generateZeroSliceGuard) {
  // Only a constant padding value is supported.
  Value padValue = padOp.getConstantPaddingValue();
  if (!padValue)
    return nullptr;

  Location loc = padOp->getLoc();
  // ... (affine helpers `add`, `sub`, `min`, `max`, the constant `zero`, and
  // the result vectors `newOffsets`/`newLengths`/`newStrides`,
  // `newLows`/`staticNewLows`, `newHighs`/`staticNewHighs` are set up here) ...

  // Helper to append `val` to the static index vector if it is a constant, and
  // to the dynamic vector otherwise.
  auto appendIndex = [&](Value val, SmallVector<Value> &dynIndices,
                         SmallVector<int64_t> &staticIndices) {
    if (auto constInt = getConstantIntValue(val)) {
      staticIndices.push_back(*constInt);
    } else {
      staticIndices.push_back(ShapedType::kDynamic);
      dynIndices.push_back(val);
    }
  };

  // Whether the result of this slice is statically known to be empty, plus a
  // runtime condition for the dynamic case.
  bool hasZeroLen = false;
  Value dynHasZeroLenCond;

  int64_t rank = padOp.getSourceType().getRank();
  for (unsigned dim = 0; dim < rank; ++dim) {
    // ... (per-dimension values `low`, `offset`, `length` and the
    // `hasLowPad`/`hasHighPad` flags are computed here) ...
    auto srcSize = b.createOrFold<tensor::DimOp>(loc, padOp.getSource(), dim);

    // Low padding that remains in front of the new slice.
    Value newLow = hasLowPad ? max(zero, sub(low, offset)) : zero;
    appendIndex(newLow, newLows, staticNewLows);

    // Offset of the new slice into the original source.
    Value newOffset = hasLowPad ? min(max(sub(offset, low), zero), srcSize)
                                : min(offset, srcSize);

    // End of the slice, clamped to the source size, and the resulting length.
    Value endLoc = hasLowPad
                       ? min(max(add(sub(offset, low), length), zero), srcSize)
                       : min(add(offset, length), srcSize);
    Value newLength = sub(endLoc, newOffset);

    // Track statically and dynamically known zero-length slices.
    if (auto newLengthInt = getConstantIntValue(newLength)) {
      hasZeroLen |= *newLengthInt == 0;
    } else {
      Value check = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            newLength, zero);
      dynHasZeroLenCond =
          dynHasZeroLenCond
              ? b.create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
              : check;
    }

    // High padding that remains behind the new slice.
    Value newHigh = hasHighPad ? sub(sub(length, newLength), newLow) : zero;
    appendIndex(newHigh, newHighs, staticNewHighs);
    // ... (the new offsets/lengths/strides of the slice are recorded here) ...
  }
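  // Worked example for one dimension (hypothetical numbers): with low = 2,
  // high = 3, srcSize = 10, and a requested tile at offset = 0 with
  // length = 5: newLow = max(0, 2 - 0) = 2, newOffset =
  // min(max(0 - 2, 0), 10) = 0, endLoc = min(max(0 - 2 + 5, 0), 10) = 3, so
  // newLength = 3 and newHigh = (5 - 3) - 2 = 0; the tile is 2 padded elements
  // followed by 3 source elements.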
  // The shape of the result can be obtained from the sizes passed in.
  SmallVector<Value> dynDims;
  SmallVector<int64_t> shape;
  dispatchIndexOpFoldResults(sizes, dynDims, shape);
  RankedTensorType resultType =
      RankedTensorType::get(shape, padOp.getResultType().getElementType());

  // Insert a cast so the produced value always has the inferred result type.
  auto castResult = [&](Value val) -> Operation * {
    return b.create<tensor::CastOp>(loc, resultType, val);
  };

  // Emit a GenerateOp yielding the padding value; used when the original data
  // source is not needed at all.
  auto createGenerateOp = [&]() {
    auto generateOp = b.create<tensor::GenerateOp>(
        loc, resultType, dynDims,
        [&](OpBuilder &builder, Location gLoc, ValueRange indices) {
          builder.create<tensor::YieldOp>(gLoc, padValue);
        });
    return castResult(generateOp);
  };

  // Emit pad(extract_slice(x)); must not be used when the new slice would have
  // a zero-sized dimension.
  auto createPadOfExtractSlice = [&]() {
    auto newSliceOp = b.create<tensor::ExtractSliceOp>(
        loc, padOp.getSource(), newOffsets, newLengths, newStrides);
    auto newPadOp = b.create<PadOp>(loc, newSliceOp, staticNewLows,
                                    staticNewHighs, newLows, newHighs);

    // Copy the region to the new PadOp.
    IRMapping bvm;
    padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);

    // Cast result and return.
    return castResult(newPadOp);
  };

  // Rewrite extract_slice(pad(x)) into a GenerateOp when it is statically
  // known that the original data source x is not used.
  if (hasZeroLen)
    return createGenerateOp();

  // With dynamic dimensions, generate an scf.if guard to avoid creating slices
  // with result dimensions of size 0 at runtime.
  if (generateZeroSliceGuard && dynHasZeroLenCond) {
    auto result = b.create<scf::IfOp>(
        loc, dynHasZeroLenCond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createGenerateOp()->getResult(0));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          b.create<scf::YieldOp>(loc, createPadOfExtractSlice()->getResult(0));
        });
    return result;
  }
  return createPadOfExtractSlice();
}
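// Illustrative rewrite (schematic IR): bubbling the slice up through the pad
// turns
//
//   %p = tensor.pad %src low[2] high[3] { ... }
//       : tensor<10xf32> to tensor<15xf32>
//   %t = tensor.extract_slice %p[0] [5] [1] : tensor<15xf32> to tensor<5xf32>
//
// into, roughly,
//
//   %s = tensor.extract_slice %src[0] [3] [1] : tensor<10xf32> to tensor<3xf32>
//   %t = tensor.pad %s low[2] high[0] { ... } : tensor<3xf32> to tensor<5xf32>
//
// and, when sizes are dynamic and generateZeroSliceGuard is set, the
// pad-of-slice is wrapped in an scf.if that falls back to a tensor.generate of
// the padding value whenever the computed slice would be empty.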
void mlir::tensor::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}

void mlir::tensor::registerTilingInterfaceExternalModelsForPackUnPackOps(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) {
    tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
    tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
  });
}
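// Illustrative usage (sketch, not part of this file): a client registers the
// models on a DialectRegistry before creating the MLIRContext, after which
// tensor.pad, tensor.pack and tensor.unpack can be cast to TilingInterface.
//
//   DialectRegistry registry;
//   registry.insert<tensor::TensorDialect>();
//   tensor::registerTilingInterfaceExternalModels(registry);
//   MLIRContext context(registry);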