PadTilingInterface.cpp
//===- PaddingTilingInterface.cpp - Padding of TilingInterface ops --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"

#define DEBUG_TYPE "pad-tiling-interface"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::tensor;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

/// Form a "full-rank" padding specification so that the application is easy.
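/// For instance (illustrative), with three loop dimensions, `indexingSizes`
/// of `[%i, %j, %k]` and `options.paddingSizes = [16]`, the completed
/// specification is `[16, %j, %k]`, or `[16, 1, 1]` when padding to a
/// multiple (a multiple of 1 leaves the size unchanged).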
static SmallVector<OpFoldResult>
getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
                        const PadTilingInterfaceOptions &options) {
  SmallVector<OpFoldResult> paddingSizes;
  // Complete the padding specification to specify all dimensions.
  for (size_t idx = 0, e = indexingSizes.size(); idx != e; ++idx) {
    // Complete to zero if needed.
    paddingSizes.push_back(options.paddingSizes.size() > idx
                               ? options.paddingSizes[idx]
                               : b.getIndexAttr(0));
    // If a dimension is zero (either specified or completed), replace it by:
    //   - 1 if we are padding to the next multiple of some value;
    //   - indexingSizes[idx] otherwise.
    if (isZeroInteger(paddingSizes[idx])) {
      paddingSizes[idx] =
          options.padToMultipleOf ? b.getIndexAttr(1) : indexingSizes[idx];
    }
    LLVM_DEBUG(DBGS() << "----idx: " << idx << " : " << paddingSizes[idx]
                      << "\n");
  }
  return paddingSizes;
}

/// Extracts the constant multiplier from an affine expression of the form
/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
/// AffineConstantExpr. Returns 1 if the expression is not a simple
/// multiplication of a dimension and a constant.
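/// For example (illustrative), `d1 * 3` and `3 * d1` both yield 3, while
/// `d0 + d1` or a bare `d0` yield 1.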
static int64_t extractConstantMultiplier(AffineExpr expr) {
  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
    if (binOp.getKind() == AffineExprKind::Mul) {
      auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
      auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
      if (lhsD && rhsC) {
        return rhsC.getValue();
      }
      auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
      auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
      if (lhsC && rhsD) {
        return lhsC.getValue();
      }
    }
  }
  return 1;
}

/// Compute the padded shape of the given value `v` of `RankedTensorType`,
/// given:
///   - `indexingSizes`, a list of OpFoldResult;
///   - an `indexingMap` that encodes how the shape varies with increases in
///     `indexingSizes`.
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementation below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
/// The padded shape is computed by evaluating the maximum accessed index per
/// dimension, which may involve multiplying by constant factors derived from
/// the affine indexing expressions. Currently, only a limited set of projected
/// permutation indexing maps are supported, such as
///   - affine_map<(d0, d1, d2) -> (d0, d1)>
///   - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
///   - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
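/// For example (illustrative), for a result expression `d0 * 3 + d1` with
/// padded sizes [%sz0, %sz1] and padToMultipleOf unset, the maximum accessed
/// index is (%sz0 - 1) * 3 + (%sz1 - 1), so the padded size of that result
/// dimension is (%sz0 - 1) * 3 + (%sz1 - 1) + 1.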
SmallVector<OpFoldResult> linalg::computePaddedShape(
    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
    const PadTilingInterfaceOptions &options) {
  Location loc = v.getLoc();
  SmallVector<OpFoldResult> paddedShape;
  auto tensorType = cast<RankedTensorType>(v.getType());
  paddedShape.resize_for_overwrite(tensorType.getRank());
  assert(tensorType.getRank() == indexingMap.getNumResults() &&
         "expect the number of results of the affine map to match the tensor "
         "rank");

  // "Full-rank" padding specification.
  SmallVector<OpFoldResult> paddingSizes =
      getFullRankPaddingSizes(rewriter, indexingSizes, options);

  // For each dimension in the operand's shape, iterate over indexingSizes and
  // add the various term contributions.
  for (const auto &enResults : enumerate(indexingMap.getResults())) {
    int64_t resultIndex = enResults.index();
    AffineMap partialIndexingMap = indexingMap.getSubMap(
        ArrayRef<unsigned>{static_cast<unsigned>(resultIndex)});

    LLVM_DEBUG(DBGS() << "----resultIndex: " << resultIndex
                      << " with partialIndexingMap: " << partialIndexingMap
                      << "\n");

    // Find all padding dimensions that contribute to this operand dimension
    // and compute the padded term contribution to the final padded shape.
    SmallVector<OpFoldResult> terms;
    for (size_t paddingDim = 0, e = paddingSizes.size(); paddingDim != e;
         ++paddingDim) {
      OpFoldResult paddingSize = paddingSizes[paddingDim];
      LLVM_DEBUG(DBGS() << "------try apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n");
      if (!enResults.value().isFunctionOfDim(paddingDim))
        continue;

      LLVM_DEBUG(DBGS() << "------apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n");

      // Project non-'paddingDim' dimensions and compress the result.
      llvm::SmallBitVector projectedDims(partialIndexingMap.getNumDims(), true);
      projectedDims.flip(paddingDim);
      AffineMap projectedMap =
          mlir::projectDims(partialIndexingMap, projectedDims,
                            /*compressDims=*/true);

      // If we are padding to the next multiple, compose with
      // ceilDiv(sz, multiple) * multiple.
      OpFoldResult paddingDimOfr;
      if (options.padToMultipleOf) {
        AffineExpr d0, s0;
        bindDims(rewriter.getContext(), d0);
        bindSymbols(rewriter.getContext(), s0);
        AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
        AffineMap composedMap = projectedMap.compose(ceilMap);
        paddingDimOfr = affine::makeComposedFoldedAffineApply(
            rewriter, loc, composedMap,
            {indexingSizes[paddingDim], paddingSize},
            /*composeAffineMin=*/true);
      } else {
        // Otherwise just set to paddingSize.
        paddingDimOfr = affine::makeComposedFoldedAffineApply(
            rewriter, loc, projectedMap, paddingSize);
      }
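      // Illustrative example: with padToMultipleOf set, a padding size of 16
      // and indexingSizes[paddingDim] == %sz, the composed map evaluates to
      // ceilDiv(%sz, 16) * 16 (scaled by any constant multiplier in
      // projectedMap), i.e. %sz rounded up to the next multiple of 16.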

      // Adjust for the maximum accessed index, which is (paddingSize - 1) *
      // multiplier.
      AffineExpr d0;
      bindDims(rewriter.getContext(), d0);
      int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
      AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
      OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
          rewriter, loc, subtractMap, {paddingDimOfr});
      terms.push_back(maxAccessIdx);

      LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
    }

    // If there are no terms, just return the dim.
    if (terms.empty()) {
      paddedShape[resultIndex] =
          createFoldedDimOp(rewriter, loc, v, resultIndex);
      continue;
    }

    // Sum individual terms' contributions.
    SmallVector<AffineExpr> dims(terms.size());
    bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
    AffineExpr sumExpr = dims.front();
    for (unsigned i = 1; i < dims.size(); ++i)
      sumExpr = sumExpr + dims[i];
    // Add 1 to the maximum accessed index and get the final padded size.
    OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, sumExpr + 1, terms);
    paddedShape[resultIndex] = paddedDimOfr;
  }

  return paddedShape;
}

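/// Compute the padded shape for `operandToPad` of an op implementing
/// IndexingMapOpInterface: the iteration domain upper bounds are used as the
/// indexing sizes passed to computePaddedShape.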
FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
    RewriterBase &rewriter, OpOperand &operandToPad,
    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
  auto transferOp =
      llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
  if (!transferOp)
    return failure();

  // clang-format off
  assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
    return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
           r.stride == OpFoldResult(rewriter.getIndexAttr(1));
  }) && "expected 0-offset 1-stride loop ranges");
  // clang-format on
  SmallVector<OpFoldResult> loopUpperBounds;
  loopUpperBounds.reserve(iterationDomain.size());
  for (const Range &range : iterationDomain)
    loopUpperBounds.push_back(range.size);

  AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
  return computePaddedShape(
      rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
      indexingMap, loopUpperBounds, options);
}

/// Pad a single operand to `paddedShape` using `paddingValueAttr` as the
/// padding value.
static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
                        TypedValue<RankedTensorType> v,
                        ArrayRef<OpFoldResult> paddedShape,
                        Attribute paddingValueAttr) {
  Value paddingValue;
  if (auto complexTy =
          dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
    auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
    paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
                                               complexTy, complexAttr);
  } else {
    paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
                                             cast<TypedAttr>(paddingValueAttr));
  }

  // Pad the operand to the bounding box defined by `paddedShape`.
  SmallVector<int64_t> tensorShape;
  SmallVector<Value> dynDims;
  for (OpFoldResult ofr : paddedShape) {
    std::optional<int64_t> cst = getConstantIntValue(ofr);
    tensorShape.push_back(cst.has_value() ? *cst : ShapedType::kDynamic);
    if (!cst.has_value())
      dynDims.push_back(cast<Value>(ofr));
  }
  // TODO: use dispatchIndexOpFoldResults(paddedShape, dynDims, paddedShape);

  auto paddedTensorType =
      RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                    << paddedTensorType);
  return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
                               paddingValue, /*nofold=*/false, dynDims);
}

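/// Pad each ranked tensor operand of `opToPad` to the shape computed by
/// `computePaddingSizeFun`, clone the op onto the padded operands and replace
/// the original results with slices of the padded results. Any created
/// tensor::PadOp is collected in `padOps`.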
FailureOr<TilingInterface>
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
                          const PadTilingInterfaceOptions &constOptions,
                          SmallVector<tensor::PadOp> &padOps,
                          PadSizeComputationFunction computePaddingSizeFun) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");

  Location loc = opToPad.getLoc();
  PadTilingInterfaceOptions options(constOptions);
  // Allow inference of pad values if they are not explicitly specified.
  // TODO: be mindful about the value depending on the actual operation.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(types, opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }

  if (llvm::any_of(opToPad->getOperands(),
                   [](Value v) { return isa<MemRefType>(v.getType()); })) {
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");
  }

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after opToPad because we also take the dims of opToPad's output.
  rewriter.setInsertionPointAfter(opToPad);

  // 1. Get the loopUpperBounds from the TilingInterface.
  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);

  // 2. For each operand.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    Value operand = opOperand.get();
    LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");

    // 2.a. Skip scalar-like operands.
    Type operandType = operand.getType();
    if (!isa<RankedTensorType>(operandType)) {
      assert((!isa<ShapedType>(operandType) || isa<VectorType>(operandType)) &&
             "Unexpected non-vector ShapedType");
      newOperands.push_back(operand);
      continue;
    }
    // 2.b. Compute the padded shape.
    FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
    if (failed(maybePaddedShape)) {
      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
    }

    // 2.c. Expect proper `paddingValues`.
    // TODO: we may want to allow garbage padding in the future, in which case
    // we would just not assert.
    if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
      return rewriter.notifyMatchFailure(opToPad,
                                         "--no padding value specified");
    }
    Attribute paddingValueAttr =
        options.paddingValues[opOperand.getOperandNumber()];

    // 2.d. Perform the actual padding.
    Value paddedOperand = padOperand(
        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
        *maybePaddedShape, paddingValueAttr);
    LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");

    // 2.e. Record the padded operand and the tensor::PadOp that produced it,
    // if any.
    newOperands.push_back(paddedOperand);
    if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
  }

  // 3. Form the resulting tensor::ExtractSliceOps.
  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
  // clone **should** properly notify the rewriter.
  TilingInterface paddedOp =
      clone(rewriter, opToPad, resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // opToPad around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
        rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
        strides));
  }

  rewriter.replaceOp(opToPad, paddedSubtensorResults);

  return paddedOp;
}