MLIR  18.0.0git
Padding.cpp
//===- Padding.cpp - Padding of Linalg ops --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

#define DEBUG_TYPE "linalg-padding"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

/// Compute the padded shape of the given operand. The operand is padded to a
/// static bounding box according to the specified options.
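///
/// For example (illustrative, not from the original source): an operand of
/// type tensor<?x12xf32> whose two shape dimensions are both functions of the
/// padding dimensions, with padToMultipleOf = {16, 16} and a constant upper
/// bound of 20 on the dynamic size, is assigned the padded shape {32, 16}.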
static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
                                        OpOperand *opOperand,
                                        const LinalgPaddingOptions &options,
                                        SmallVector<int64_t> &paddedShape,
                                        bool &alreadyHasRequestedShape) {
  AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
  ArrayRef<int64_t> shape = opToPad.getShape(opOperand);

  // Collect the shape dimensions that are a function of "paddingDimensions",
  // along with the multiple that they should be padded to ("1" if none).
  alreadyHasRequestedShape = true;
  DenseMap<int64_t, int64_t> shapeDimToMultiple;
  for (const auto &dimEn : enumerate(options.paddingDimensions)) {
    for (const auto &en : enumerate(indexingMap.getResults())) {
      if (en.value().isFunctionOfDim(dimEn.value())) {
        int64_t dimSize = shape[en.index()];
        if (options.padToMultipleOf.has_value()) {
          shapeDimToMultiple[en.index()] =
              (*options.padToMultipleOf)[dimEn.index()];
        } else {
          shapeDimToMultiple[en.index()] = 1;
        }
        if (ShapedType::isDynamic(dimSize)) {
          alreadyHasRequestedShape = false;
        } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
          alreadyHasRequestedShape = false;
        }
      }
    }
  }

  // Helper function to round a number up to a given multiple.
  auto ceil = [](int64_t val, int64_t multiple) {
    return ((val + multiple - 1) / multiple) * multiple;
  };
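  // E.g. (illustrative): ceil(20, 16) == 32 and ceil(16, 16) == 16.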

  // Upper bound the sizes to obtain a static bounding box.
  paddedShape.assign(shape.begin(), shape.end());
  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
    // Skip dimensions that do not require padding.
    if (!shapeDimToMultiple.contains(i)) {
      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
      continue;
    }
    // Otherwise, try to compute a constant upper bound for the size value.
    FailureOr<int64_t> upperBound =
        ValueBoundsConstraintSet::computeConstantBound(
            presburger::BoundType::UB, opOperand->get(),
            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(upperBound)) {
      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
      return failure();
    }
    paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
  }

  return success();
}

/// Pad the `opOperand` in the "paddingDimensions" using the padding value and
/// the nofold flag found in "paddingValues" and "packPaddings", respectively.
///
/// Exit early and return the `opOperand` value if it already has the requested
/// shape. I.e.:
/// - static shape
/// - nofold is not set
/// - dim sizes are multiples of "padToMultipleOf"
///
/// Otherwise, try to pad the shape dimensions that match the iterator
/// dimensions "paddingDimensions" and return the tensor::PadOp result if
/// padding succeeds or failure otherwise.
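///
/// For example (illustrative, not from the original source): padding a
/// tensor<?x12xf32> operand to the bounding box {32, 16} with a zero padding
/// value produces IR of the form:
///
///   %cst = arith.constant 0.000000e+00 : f32
///   %padded = tensor.pad %operand low[0, 0] high[%h0, 4] {
///   ^bb0(%i: index, %j: index):
///     tensor.yield %cst : f32
///   } : tensor<?x12xf32> to tensor<32x16xf32>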
static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
    RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
    const LinalgPaddingOptions &options) {
  assert(
      (!options.padToMultipleOf.has_value() ||
       options.padToMultipleOf->size() == options.paddingDimensions.size()) &&
      "invalid number of elements in padToMultipleOf");

  // Compute padded shape.
  SmallVector<int64_t> paddedShape;
  bool alreadyHasRequestedShape = false;
  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
                                alreadyHasRequestedShape)))
    return rewriter.notifyMatchFailure(opToPad,
                                       "--failed to compute padded shape");

  // Return the unpadded operand if padding to a static shape is not needed and
  // if the nofold flag is not set.
  bool nofold = opOperand->getOperandNumber() < options.packPaddings.size()
                    ? options.packPaddings[opOperand->getOperandNumber()]
                    : false;
  if (!nofold && alreadyHasRequestedShape)
    return opOperand->get();

  // Fail if `paddingValues` specifies no padding value.
  if (opOperand->getOperandNumber() >= options.paddingValues.size()) {
    return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
  }
  Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];

  Value paddingValue;
  if (auto complexTy = dyn_cast<ComplexType>(
          getElementTypeOrSelf(opOperand->get().getType()))) {
    auto complexAttr = cast<ArrayAttr>(paddingAttr);
    paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(),
                                                        complexTy, complexAttr);
  } else {
    paddingValue = rewriter.create<arith::ConstantOp>(
        opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
  }

  // Pad the operand to the bounding box defined by `paddedShape`.
  auto paddedTensorType = RankedTensorType::get(
      paddedShape, getElementTypeOrSelf(opOperand->get()));
  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                    << paddedTensorType);
  return makeComposedPadHighOp(rewriter, opToPad->getLoc(), paddedTensorType,
                               opOperand->get(), paddingValue, nofold);
}

LogicalResult
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
                          const LinalgPaddingOptions &constOptions,
                          LinalgOp &paddedOp, SmallVector<Value> &replacements,
                          SmallVector<tensor::PadOp> &padOps) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
  Location loc = opToPad->getLoc();

  LinalgPaddingOptions options(constOptions);
  // Allow inference of pad values if they are not explicitly specified.
  // TODO: be mindful about the value depending on the actual operation.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(types, opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }
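  // E.g. (illustrative): with this default, an f32 operand is padded with a
  // 0.0 : f32 constant.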

  // TODO: there are cases where we may still want to pad to larger sizes.
  if (!opToPad.hasTensorSemantics())
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after op because we also take the dims of the original output.
  rewriter.setInsertionPointAfter(opToPad);

  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
        rewriter, opToPad, &opOperand, options);
    // Exit if `paddingDimensions` cannot be bounded statically.
    if (failed(paddedOperand)) {
      LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
                        << opOperand.get() << " -> FAIL\n");
      return rewriter.notifyMatchFailure(opToPad,
                                         "operand cannot be bound statically");
    }
    newOperands.push_back(*paddedOperand);
    if (auto padOp = paddedOperand->getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
  }

  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes();
  // clone **should** properly notify the rewriter.
  paddedOp = clone(rewriter, opToPad, resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // linalg op around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
        strides));
  }
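  // E.g. (illustrative): a result padded to tensor<32x16xf32> is sliced back
  // to its original size with
  //   tensor.extract_slice %padded[0, 0] [%d0, 12] [1, 1]
  //       : tensor<32x16xf32> to tensor<?x12xf32>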

  if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None) {
    replacements = std::move(paddedSubtensorResults);
    return success();
  }

  // Copy back unpadded results to the original destination (i.e., inits of the
  // linalg op), so that the destination buffer of the computation does not
  // change. If the padding folds away, this will materialize as a memcpy
  // between two identical buffers, which will then also fold away.
  assert(static_cast<int64_t>(paddedSubtensorResults.size()) ==
             opToPad.getNumDpsInits() &&
         "expected matching number of results");
  for (auto it :
       llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) {
    if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) {
      replacements.push_back(rewriter
                                 .create<linalg::CopyOp>(loc, std::get<0>(it),
                                                         std::get<1>(it).get())
                                 .getResult(0));
    } else if (options.copyBackOp ==
               LinalgPaddingOptions::CopyBackOp::
                   BufferizationMaterializeInDestination) {
      replacements.push_back(
          rewriter
              .create<bufferization::MaterializeInDestinationOp>(
                  loc, std::get<0>(it), std::get<1>(it).get())
              ->getResult(0));
    } else {
      llvm_unreachable("unsupported copy back op");
    }
  }
  return success();
}

FailureOr<LinalgOp>
mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
                                  const LinalgPaddingOptions &options) {
  assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
         "invalid options");

  if (!linalgOp.hasTensorSemantics())
    return rewriter.notifyMatchFailure(
        linalgOp, "only applies to Linalg ops with tensor semantics");

  // Pad the operation.
  LinalgOp paddedOp;
  SmallVector<Value> newResults;
  SmallVector<tensor::PadOp> padOps;
  if (failed(rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp,
                               newResults, padOps)))
    return rewriter.notifyMatchFailure(linalgOp,
                                       "failed to rewrite as a padded op");

  // Hoist the padding.
  for (const auto &en : enumerate(options.hoistPaddings)) {
    if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
      break;
    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
    if (!padOp || en.value() == 0) {
      (void)rewriter.notifyMatchFailure(linalgOp, "not a tensor.pad -- skip");
      continue;
    }

    // Fail hoisting if the operand shape is not fully static.
    if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "non static padding shape -- skip");
      continue;
    }

    tensor::PadOp hoistedOp;
    SmallVector<GenericOp> transposeOps;
    SmallVector<int64_t> transposeVector =
        en.index() < options.transposePaddings.size()
            ? options.transposePaddings[en.index()]
            : SmallVector<int64_t>{};

    FailureOr<Value> newResult = hoistPaddingOnTensors(
        rewriter, padOp, en.value(), transposeVector, hoistedOp, transposeOps);
    if (failed(newResult)) {
      (void)rewriter.notifyMatchFailure(linalgOp,
                                        "failed to apply hoistPadding");
      continue;
    }
    rewriter.replaceOp(padOp, *newResult);
  }

  // Replace the original operation to pad.
  rewriter.replaceOp(linalgOp, newResults);

  return paddedOp;
}
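
// A minimal usage sketch (hypothetical, not part of the original file; the
// names `rewriter` and `linalgOp` are assumed to come from a surrounding
// rewrite pattern): pad all three iteration dimensions of a matmul to
// multiples of 16 and skip the copy-back.
//
//   LinalgPaddingOptions options;
//   options.paddingDimensions = {0, 1, 2};
//   options.padToMultipleOf = SmallVector<int64_t>{16, 16, 16};
//   options.copyBackOp = LinalgPaddingOptions::CopyBackOp::None;
//   FailureOr<LinalgOp> maybePadded =
//       padAndHoistLinalgOp(rewriter, linalgOp, options);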