27 struct PadOpTiling :
public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
30 auto padOp = cast<PadOp>(op);
32 padOp.getResultType().getRank(), utils::IteratorType::parallel);
44 for (
const auto &ub :
enumerate(reifiedShapes[0]))
45 loopRanges[ub.index()].size = ub.value();
57 return result.value();
66 resultOffsets.assign(offsets.begin(), offsets.end());
67 resultSizes.assign(sizes.begin(), sizes.end());
72 template <
typename OpTy>
75 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
76 "applies to only pack or unpack operations");
78 int64_t rank = (std::is_same<OpTy, PackOp>::value) ? op.getSourceRank()
85 for (
auto dim : llvm::seq<int64_t>(0, rank)) {
86 loopBounds[dim].offset = zero;
87 loopBounds[dim].stride = one;
88 loopBounds[dim].size = resultShape[0][dim];
96 if (permutation.empty())
98 applyPermutationToVector<OpFoldResult>(offsets, permutation);
99 applyPermutationToVector<OpFoldResult>(sizes, permutation);
103 :
public TilingInterface::ExternalModel<PackOpTiling, PackOp> {
109 auto packOp = cast<PackOp>(op);
111 packOp.getSourceRank(), utils::IteratorType::parallel);
112 return iteratorTypes;
116 return getPackUnPackIterationDomain<PackOp>(cast<PackOp>(op), b);
123 auto packOp = cast<PackOp>(op);
128 int64_t inputRank = packOp.getSourceRank();
131 applyPermToRange(origOffsets, origSizes,
135 packOp.getDimAndTileMapping();
139 for (
auto dim : llvm::seq<int64_t>(0, inputRank)) {
145 if (dimAndTileMapping.count(dim)) {
149 auto avOffset = AV(dim0).bind(origOffsets[dim]);
150 auto avSize = AV(dim0).bind(origSizes[dim]);
151 auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
152 inputIndices.push_back(ab.mul(avOffset, avTileSize));
153 inputSizes.push_back(ab.mul(avSize, avTileSize));
155 inputIndices.push_back(origOffsets[dim]);
156 inputSizes.push_back(origSizes[dim]);
160 if (packOp.getPaddingValue()) {
162 auto avDimSize = AV(dim0).bind(dimSize);
163 auto avInputIdx = AV(dim1).bind(inputIndices.back());
165 ab.min({inputSizes.back(), ab.sub(avDimSize, avInputIdx)});
173 tiledOperands.push_back(b.
create<ExtractSliceOp>(
174 loc, packOp.getSource(), inputIndices, inputSizes, strides));
177 if (
failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets,
181 strides.append(packOp.getDestRank() - inputRank, oneAttr);
182 auto extractSlice = b.
create<ExtractSliceOp>(
183 loc, packOp.getDest(), outputOffsets, outputSizes, strides);
184 tiledOperands.push_back(extractSlice);
186 if (
auto val = packOp.getPaddingValue())
187 tiledOperands.push_back(val);
188 for (
auto tile : packOp.getInnerTiles())
189 tiledOperands.push_back(
tile);
208 auto packOp = cast<PackOp>(op);
209 int64_t inputRank = packOp.getSourceRank();
210 int64_t outputRank = packOp.getDestRank();
212 resultOffsets.assign(offsets.begin(), offsets.end());
213 resultOffsets.append(outputRank - inputRank, zeroAttr);
217 resultSizes.assign(sizes.begin(), sizes.end());
218 for (
auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
219 resultSizes.push_back(outputShape[0][dataTileDim]);
228 auto packOp = cast<PackOp>(op);
229 int64_t numTiles = packOp.getInnerDimsPos().size();
234 for (
auto offset : offsets.take_back(numTiles))
239 llvm::zip_equal(packOp.getMixedTiles(), sizes.take_back(numTiles)))
244 op, b, offsets.drop_back(numTiles), sizes.drop_back(numTiles));
247 return tilingResult.value();
251 struct UnpackTileDimInfo {
252 bool isAlignedToInnerTileSize;
262 static UnpackTileDimInfo getUnpackTileDimInfo(
OpBuilder &b, UnPackOp unpackOp,
266 UnpackTileDimInfo info;
270 unpackOp.getDimAndTileMapping();
272 if (!dimAndTileMapping.count(tileDim)) {
273 info.isAlignedToInnerTileSize =
true;
274 info.sourceOffset = tileOffset;
275 info.sourceSize = tileSize;
276 info.resultOffset = zeroAttr;
277 info.destExpandedSize = tileSize;
288 OpFoldResult innerTileSize = dimAndTileMapping[tileDim];
290 info.isAlignedToInnerTileSize =
false;
295 if (!
failed(cstSize) && cstInnerSize) {
296 if (*cstSize % *cstInnerSize == 0)
297 info.isAlignedToInnerTileSize =
true;
301 if (*cstInnerSize == *cstSize) {
302 auto lhs = AV(dim0).bind(tileOffset);
303 auto rhs = AV(dim1).bind(innerTileSize);
304 info.sourceOffset = ab.floor(lhs, rhs);
305 info.sourceSize = oneAttr;
306 info.resultOffset = zeroAttr;
307 info.destExpandedSize = tileSize;
312 if (info.isAlignedToInnerTileSize) {
314 ab.floor(AV(dim0).bind(tileOffset), AV(dim1).bind(innerTileSize));
315 info.resultOffset = zeroAttr;
316 info.destExpandedSize = tileSize;
325 ab.ceil(AV(dim0).bind(tileSize), AV(dim1).bind(innerTileSize));
333 ab.add(AV(dim0).bind(tileOffset), AV(dim1).bind(tileSize));
338 ab.sub(AV(dim0).bind(tileExclusiveBound), AV(dim1).bind(oneAttr))),
342 AV(dim1).bind(firstCoord.
quotient));
344 ab.add(AV(dim0).bind(lengthMinusOne), AV(dim1).bind(oneAttr));
345 info.sourceOffset = firstCoord.
quotient;
346 info.resultOffset = firstCoord.
remainder;
355 struct UnPackOpTiling
356 :
public TilingInterface::ExternalModel<UnPackOpTiling, UnPackOp> {
359 auto unpackOp = cast<UnPackOp>(op);
361 unpackOp.getDestRank(), utils::IteratorType::parallel);
362 return iteratorTypes;
366 return getPackUnPackIterationDomain<UnPackOp>(cast<UnPackOp>(op), b);
387 auto unpackOp = cast<UnPackOp>(op);
388 int64_t srcRank = unpackOp.getSourceRank();
389 int64_t destRank = unpackOp.getDestRank();
390 int64_t numInnerTiles = srcRank - destRank;
396 bool isPerfectTilingCase =
true;
401 for (
auto dim : llvm::seq<int64_t>(0, destRank)) {
402 UnpackTileDimInfo info =
403 getUnpackTileDimInfo(b, unpackOp, dim, offsets[dim], sizes[dim]);
404 if (!info.isAlignedToInnerTileSize)
405 isPerfectTilingCase =
false;
406 sliceSrcIndices.push_back(info.sourceOffset);
407 sliceSrcSizes.push_back(info.sourceSize);
408 destExpandedSizes.push_back(info.destExpandedSize);
409 resultOffsetsFromDest.push_back(info.resultOffset);
414 applyPermToRange(sliceSrcIndices, sliceSrcSizes,
415 unpackOp.getOuterDimsPerm());
417 sliceSrcIndices.append(numInnerTiles, zeroAttr);
418 sliceSrcSizes.append(unpackOp.getMixedTiles());
419 sliceSrcStrides.append(numInnerTiles, oneAttr);
421 b.
create<ExtractSliceOp>(loc, unpackOp.getSource(), sliceSrcIndices,
422 sliceSrcSizes, sliceSrcStrides);
426 if (isPerfectTilingCase) {
427 sliceDest = b.
create<ExtractSliceOp>(loc, unpackOp.getDest(), offsets,
430 sliceDest = b.
create<EmptyOp>(loc, destExpandedSizes,
431 unpackOp.getDestType().getElementType());
435 for (
auto tile : unpackOp.getInnerTiles())
436 tiledOperands.push_back(
tile);
441 if (isPerfectTilingCase)
447 resultOffsetsFromDest, sizes, destStrides);
457 resultOffsets = llvm::to_vector(offsets);
458 resultSizes = llvm::to_vector(sizes);
467 getTiledImplementation(op, b, offsets, sizes);
470 return tilingResult.value();
480 bool generateZeroSliceGuard) {
482 Value padValue = padOp.getConstantPaddingValue();
517 bool hasZeroLen =
false;
520 Value dynHasZeroLenCond;
522 int64_t rank = padOp.getSourceType().getRank();
523 for (
unsigned dim = 0; dim < rank; ++dim) {
524 auto low = padOp.getMixedLowPad()[dim];
526 auto high = padOp.getMixedHighPad()[dim];
528 auto offset = offsets[dim];
529 auto length = sizes[dim];
538 newLows.push_back(newLow);
553 ?
min(
max(sub(offset, low), zero), srcSize)
554 :
min(offset, srcSize);
555 newOffsets.push_back(newOffset);
577 hasLowPad ?
min(
max(add(sub(offset, low), length), zero), srcSize)
578 :
min(add(offset, length), srcSize);
580 newLengths.push_back(newLength);
586 }
else if (!hasZeroLen) {
588 loc, arith::CmpIPredicate::eq,
593 ? b.
create<arith::OrIOp>(loc, check, dynHasZeroLenCond)
602 hasHighPad ? sub(sub(length, newLength), newLow) : zero;
603 newHighs.push_back(newHigh);
613 RankedTensorType resultType =
618 if (resultType == val.getType())
620 return b.
create<tensor::CastOp>(loc, resultType, val);
626 auto createGenerateOp = [&]() {
628 auto generateOp = b.
create<tensor::GenerateOp>(
629 loc, resultType, dynDims,
631 builder.create<tensor::YieldOp>(gLoc, padValue);
638 auto createPadOfExtractSlice = [&]() {
640 Value newSliceOp = b.
create<tensor::ExtractSliceOp>(
641 loc, padOp.getSource(), newOffsets, newLengths, newStrides);
642 auto newPadOp = b.
create<PadOp>(
643 loc,
Type(), newSliceOp, newLows, newHighs,
649 padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm);
658 Operation *generateOp = createGenerateOp();
664 if (generateZeroSliceGuard && dynHasZeroLenCond) {
667 auto result = b.
create<scf::IfOp>(
668 loc, dynHasZeroLenCond,
671 thenOp = createGenerateOp();
672 b.create<scf::YieldOp>(loc, castResult(thenOp->
getResult(0)));
676 elseOp = createPadOfExtractSlice();
677 b.create<scf::YieldOp>(loc, castResult(elseOp->
getResult(0)));
682 Operation *newPadOp = createPadOfExtractSlice();
689 tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
690 tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
691 tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
698 tensor::PackOp::attachInterface<PackOpTiling>(*ctx);
699 tensor::UnPackOp::attachInterface<UnPackOpTiling>(*ctx);
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI64IntegerAttr(int64_t value)
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
result_range getResults()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, const Variable &var, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given variable.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs)
Create IR to calculate (div lhs, rhs) and (mod lhs, rhs).
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
void registerTilingInterfaceExternalModels(mlir::DialectRegistry &registry)
Registers external models for Tiling interface for tensor ops.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
void registerTilingInterfaceExternalModelsForPackUnPackOps(DialectRegistry &registry)
Similar to the above registration, but it is only for tensor.pack and tensor.unpack ops.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
This class represents an efficient way to signal success or failure.
Container for result values of tiling.
Helper struct to build simple AffineValueExprs with minimal type inference support.
Holds the result of (div a, b) and (mod a, b).