17 #define DEBUG_TYPE "linalg-padding"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
23 #define DBGSNL() (llvm::dbgs() << "\n")
31 bool &alreadyHasRequestedShape) {
32 AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
37 alreadyHasRequestedShape =
true;
41 if (en.value().isFunctionOfDim(dimEn.value())) {
42 int64_t dimSize = shape[en.index()];
43 if (
options.padToMultipleOf.has_value()) {
44 shapeDimToMultiple[en.index()] =
45 (*
options.padToMultipleOf)[dimEn.index()];
47 shapeDimToMultiple[en.index()] = 1;
49 if (ShapedType::isDynamic(dimSize)) {
50 alreadyHasRequestedShape =
false;
51 }
else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
52 alreadyHasRequestedShape =
false;
59 auto ceil = [](int64_t val, int64_t multiple) {
60 return ((val + multiple - 1) / multiple) * multiple;
64 paddedShape.assign(shape.begin(), shape.end());
65 for (int64_t i = 0, e = shape.size(); i < e; ++i) {
66 LLVM_DEBUG(
DBGS() <<
"--compute padded size for dim " << i <<
"\n");
68 if (!shapeDimToMultiple.contains(i)) {
69 LLVM_DEBUG(
DBGS() <<
"----dim does not require padding, SKIP\n");
78 LLVM_DEBUG(
DBGS() <<
"----count not compute a bounding box for padding");
81 paddedShape[i] =
ceil(*upperBound, shapeDimToMultiple[i]);
82 LLVM_DEBUG(
DBGS() <<
"----new dim size: " << paddedShape[i] <<
"\n");
104 (!
options.padToMultipleOf.has_value() ||
105 options.padToMultipleOf->size() ==
options.paddingDimensions.size()) &&
106 "invalid number of elements in padToMultipleOf");
110 bool alreadyHasRequestedShape =
false;
112 alreadyHasRequestedShape)))
114 "--failed to compute padded shape");
121 if (!nofold && alreadyHasRequestedShape)
122 return opOperand->
get();
131 if (
auto complexTy = dyn_cast<ComplexType>(
133 auto complexAttr = cast<ArrayAttr>(paddingAttr);
134 paddingValue = rewriter.
create<complex::ConstantOp>(opToPad.getLoc(),
135 complexTy, complexAttr);
137 paddingValue = rewriter.
create<arith::ConstantOp>(
138 opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
144 LLVM_DEBUG(
DBGS() <<
"--SUCCESS, makeComposedPadHighOp with type: "
145 << paddedTensorType);
147 opOperand->
get(), paddingValue, nofold);
155 LLVM_DEBUG(
DBGS() <<
"Start rewriteAsPaddedOp : " << opToPad <<
"\n");
161 if (
options.paddingValues.empty()) {
163 llvm::append_range(types, opToPad->getResultTypes());
164 for (
Type t : types) {
165 options.paddingValues.push_back(
171 if (!opToPad.hasTensorSemantics())
173 "expected operation on tensors");
181 newOperands.reserve(opToPad->getNumOperands());
182 for (
OpOperand &opOperand : opToPad->getOpOperands()) {
184 rewriter, opToPad, &opOperand,
options);
186 if (
failed(paddedOperand)) {
187 LLVM_DEBUG(
DBGS() <<
"--operand cannot be bound statically : "
188 << opOperand.get() <<
" -> FAIL\n");
190 "operand cannot be bound statically");
192 newOperands.push_back(*paddedOperand);
193 if (
auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
194 padOps.push_back(padOp);
199 LLVM_DEBUG(
DBGS() <<
"--failed to reify result shapes -> FAIL\n");
201 "failed to reify result shapes");
203 assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
204 "expected same number of results");
207 auto resultTensorTypes =
210 paddedOp =
clone(rewriter, opToPad, resultTensorTypes, newOperands);
211 LLVM_DEBUG(
DBGS() <<
"--cloned padded op: " << paddedOp <<
"\n");
216 paddedSubtensorResults.reserve(opToPad->getNumResults());
218 Value paddedResult = en.value();
219 int64_t resultNumber = en.index();
220 int64_t rank = cast<RankedTensorType>(paddedResult.
getType()).getRank();
223 paddedSubtensorResults.push_back(rewriter.
create<tensor::ExtractSliceOp>(
224 loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
229 replacements = std::move(paddedSubtensorResults);
237 assert(
static_cast<int64_t
>(paddedSubtensorResults.size()) ==
238 opToPad.getNumDpsInits() &&
239 "expected matching number of results");
241 llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
243 replacements.push_back(rewriter
244 .create<linalg::CopyOp>(loc, std::get<0>(it),
245 std::get<1>(it).
get())
247 }
else if (
options.copyBackOp ==
249 BufferizationMaterializeInDestination) {
250 replacements.push_back(
252 .create<bufferization::MaterializeInDestinationOp>(
253 loc, std::get<0>(it), std::get<1>(it).
get())
256 llvm_unreachable(
"unsupported copy back op");
268 if (!linalgOp.hasTensorSemantics())
270 linalgOp,
"only applies to Linalg ops with tensor semantics");
277 newResults, padOps)))
279 "failed to rewrite as a padded op");
283 if (
static_cast<int64_t
>(en.index()) >= paddedOp->getNumOperands())
285 OpOperand &opOperand = paddedOp->getOpOperand(en.index());
287 if (!padOp || en.value() == 0) {
293 if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) {
295 "non static padding shape -- skip");
299 tensor::PadOp hoistedOp;
302 en.index() <
options.transposePaddings.size()
303 ?
options.transposePaddings[en.index()]
307 padOp, en.value(), transposeVector, hoistedOp, transposeOps);
310 "failed to apply hoistPadding");
317 rewriter.
replaceOp(linalgOp, newResults);
static FailureOr< Value > padOperandToSmallestStaticBoundingBox(RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, const LinalgPaddingOptions &options)
Pad the opOperand in the "paddingDimensions" using the padding value and the nofold flag found in "paddingValues" and "packPaddings", respectively.
static LogicalResult computePaddedShape(linalg::LinalgOp opToPad, OpOperand *opOperand, const LinalgPaddingOptions &options, SmallVector< int64_t > &paddedShape, bool &alreadyHasRequestedShape)
Compute the padded shape of the given operand.
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
ArrayRef< AffineExpr > getResults() const
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around an Attribute that is a subclass of LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
static FailureOr< int64_t > computeConstantBound(presburger::BoundType type, Value value, std::optional< int64_t > dim=std::nullopt, StopConditionFn stopCondition=nullptr, bool closedUB=false)
Compute a constant bound for the given affine map, where dims and symbols are bound to the given operands.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, const LinalgPaddingOptions &options, LinalgOp &paddedOp, SmallVector< Value > &replacements, SmallVector< tensor::PadOp > &padOps)
Pad the iterator dimensions paddingDimensions of all opToPad operands to a static bounding box.
FailureOr< Value > hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, int64_t numLoops, ArrayRef< int64_t > transposeVector, tensor::PadOp &hoistedOp, SmallVectorImpl< GenericOp > &transposeOps)
Mechanically hoist padding operations on tensors by numLoops into a new, generally larger tensor.
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold)
Create a tensor::PadOp that pads source to the size of the statically sized type whose static sizes are assumed to be greater than the dynamic source size.
FailureOr< LinalgOp > padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp, const LinalgPaddingOptions &options)
Apply padding and hoisting to linalgOp according to the configuration specified in options.
MPInt ceil(const Fraction &f)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.