MLIR  20.0.0git
RewriteAsConstant.cpp
Go to the documentation of this file.
1 //===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::tensor;
19 
20 namespace {
21 
22 /// Rewrite tensor.generate with arith.constant if the yielded value is a
23 /// constant and the tensor type is static.
24 struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
26 
27  LogicalResult matchAndRewrite(GenerateOp generateOp,
28  PatternRewriter &rewriter) const override {
29  auto tensorType =
30  llvm::cast<RankedTensorType>(generateOp.getResult().getType());
31  if (!tensorType.hasStaticShape())
32  return failure();
33  auto terminatorOp =
34  cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
35  Attribute attr;
36  if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
37  return failure();
38  Operation *constantOp =
39  rewriter.getContext()
40  ->getLoadedDialect<TensorDialect>()
41  ->materializeConstant(rewriter,
42  DenseElementsAttr::get(tensorType, attr),
43  tensorType, generateOp->getLoc());
44  if (!constantOp)
45  return failure();
46  rewriter.replaceOp(generateOp, constantOp->getResults());
47  return success();
48  }
49 };
50 
51 /// Transform a linear index from one indexing space to another given:
52 ///
53 /// - the shape of the source indexing space,
54 /// - the strides of the target indexing space,
55 /// - a linear index into the source indexing space.
56 ///
57 /// This function is logically a sequence of linearize/delinearize over
58 /// different bases but avoids allocating intermediate SmallVectors.
59 int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
60  ArrayRef<int64_t> outputStrides,
61  int64_t srcLinearIndex) {
62  assert(inputShape.size() == outputStrides.size());
63 
64  int64_t dstLinearIndex = 0;
65 
66  for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
67  // Compute the index into the current dimension of the source tensor.
68  // `quotient` is the remaining linear index after accounting for the
69  // current dimension.
70  //
71  // `remainder` is the index into the source tensor for the current
72  // dimension.
73  auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
74 
75  srcLinearIndex = quotient;
76 
77  // Add the contribution of the current dimension to the output using the
78  // permutation map.
79  dstLinearIndex += outputStrides[dim] * remainder;
80  }
81 
82  return dstLinearIndex;
83 }
84 
85 template <typename ElemType, typename AttrType>
86 Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
87  DenseElementsAttr input, AttrType padValue,
88  ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
89  auto inputValues = input.tryGetValues<ElemType>();
90  if (failed(inputValues))
91  return nullptr;
92 
93  auto oldShape = input.getType().getShape();
94 
95  // Compute the output shape of the new value.
96  auto newShape =
97  llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98  [](std::tuple<int64_t, int64_t, int64_t> pack) {
99  auto [old, low, high] = pack;
100  return old + low + high;
101  });
102 
103  int64_t outputSize = computeProduct(newShape);
104 
105  // Fully initialize the vector with the padding value.
106  // The non-padded area will then be copied.
107  SmallVector<ElemType> values(outputSize, padValue.getValue());
108 
109  // Strides for input and output are used to transform between the indexing
110  // space of the input and output tensors.
111  SmallVector<int64_t> outputStrides = computeStrides(newShape);
112 
113  // The contribution of the low padding to the offset in the output tensor.
114  // This is the starting position of the source tensor within the padding
115  // tensor.
116  int64_t startingOffset = linearize(padLow, outputStrides);
117 
118  // Copy values from the input tensor to the corresponding sub-region
119  // of the output tensor.
120  for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
121  auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122  values[outputIndex + startingOffset] = inputValue;
123  }
124 
125  // Create an attribute for the folded value.
126  auto newType = input.getType().clone(newShape);
127  auto newAttr = DenseElementsAttr::get(newType, values);
128 
129  Operation *constantOp =
130  rewriter.getContext()
131  ->getLoadedDialect<TensorDialect>()
132  ->materializeConstant(rewriter, newAttr, newType, loc);
133 
134  return constantOp ? constantOp->getResult(0) : nullptr;
135 }
136 
137 struct PadOpToConstant final : public OpRewritePattern<PadOp> {
138 
139  PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
140  PatternBenefit benefit = 1)
141  : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}
142 
143  LogicalResult matchAndRewrite(PadOp padTensorOp,
144  PatternRewriter &rewriter) const override {
145  if (padTensorOp.getNofold())
146  return rewriter.notifyMatchFailure(
147  padTensorOp, "refusing to fold nofold pad operation");
148 
149  TypedValue<RankedTensorType> input = padTensorOp.getSource();
150  RankedTensorType resultType = padTensorOp.getResult().getType();
151 
152  DenseElementsAttr inputAttr = nullptr;
153  if (!matchPattern(input, m_Constant(&inputAttr)))
154  return failure();
155 
156  Value paddingValue = padTensorOp.getConstantPaddingValue();
157 
158  // Extract the constant value used for padding or bail out.
159  Attribute paddingAttr = nullptr;
160  if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
161  return rewriter.notifyMatchFailure(padTensorOp,
162  "unable to get constant value");
163 
164  // Try to extract the constant values of the low and high padding.
165  auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
166  auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
167 
168  // If the padding cannot be extracted, bail out.
169  if (!lowPad || !highPad)
170  return rewriter.notifyMatchFailure(padTensorOp,
171  "unable to extract constant padding");
172 
173  // We have a potential candidate, consult the control function to
174  // determine if the op should fold.
175  if (!controlFn(&padTensorOp.getSourceMutable()))
176  return rewriter.notifyMatchFailure(padTensorOp,
177  "not folding due to cost function");
178 
179  Location loc = padTensorOp.getLoc();
180 
181  // Try constant folding the supported cases of integer and float values.
182  Value newOp =
184  .Case([&](FloatAttr floatAttr) {
185  return constantFoldPadOp<llvm::APFloat>(
186  rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
187  })
188  .Case([&](IntegerAttr integerAttr) {
189  return constantFoldPadOp<llvm::APInt>(
190  rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
191  })
192  .Default(Value());
193 
194  if (!newOp)
195  return rewriter.notifyMatchFailure(padTensorOp,
196  "tensor type not supported");
197 
198  if (newOp.getType() != resultType)
199  newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);
200 
201  rewriter.replaceOp(padTensorOp, newOp);
202  return success();
203  }
204 
205 private:
206  ControlFoldFn controlFn;
207 };
208 
209 } // namespace
210 
212  RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
213  patterns.add<GenerateToConstant>(patterns.getContext());
214 
215  patterns.add<PadOpToConstant>(patterns.getContext(), controlFn);
216 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:55
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
FailureOr< iterator_range_impl< ElementIterator< T > > > tryGetValues() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
result_range getResults()
Definition: Operation.h:410
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:480
std::function< bool(OpOperand *)> ControlFoldFn
Definition: Transforms.h:94
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explanatory: returns the product of all entries of `basis`.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358