MLIR 22.0.0git
RewriteAsConstant.cpp
Go to the documentation of this file.
//===- RewriteAsConstant.cpp - Patterns to rewrite tensor ops as constants ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
12#include "mlir/IR/Matchers.h"
14
15#include "llvm/ADT/TypeSwitch.h"
16
17using namespace mlir;
18using namespace mlir::tensor;
19
20namespace {
21
22/// Rewrite tensor.generate with arith.constant if the yielded value is a
23/// constant and the tensor type is static.
24struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
25 using OpRewritePattern<GenerateOp>::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(GenerateOp generateOp,
28 PatternRewriter &rewriter) const override {
29 auto tensorType =
30 llvm::cast<RankedTensorType>(generateOp.getResult().getType());
31 if (!tensorType.hasStaticShape())
32 return failure();
33 auto terminatorOp =
34 cast<tensor::YieldOp>(generateOp.getBody().front().getTerminator());
35 Attribute attr;
36 if (!matchPattern(terminatorOp.getValue(), m_Constant(&attr)))
37 return failure();
38 Operation *constantOp =
39 rewriter.getContext()
40 ->getLoadedDialect<TensorDialect>()
41 ->materializeConstant(rewriter,
42 DenseElementsAttr::get(tensorType, attr),
43 tensorType, generateOp->getLoc());
44 if (!constantOp)
45 return failure();
46 rewriter.replaceOp(generateOp, constantOp->getResults());
47 return success();
48 }
49};
50
51/// Transform a linear index from one indexing space to another given:
52///
53/// - the shape of the source indexing space,
54/// - the strides of the target indexing space,
55/// - a linear index into the source indexing space.
56///
57/// This function is logically a sequence of linearize/delinearize over
58/// different bases but avoids allocating intermediate SmallVectors.
59int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
60 ArrayRef<int64_t> outputStrides,
61 int64_t srcLinearIndex) {
62 assert(inputShape.size() == outputStrides.size());
63
64 int64_t dstLinearIndex = 0;
65
66 for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
67 // Compute the index into the current dimension of the source tensor.
68 // `quotient` is the remaining linear index after accounting for the
69 // current dimension.
70 //
71 // `remainder` is the index into the source tensor for the current
72 // dimension.
73 auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
74
75 srcLinearIndex = quotient;
76
77 // Add the contribution of the current dimension to the output using the
78 // permutation map.
79 dstLinearIndex += outputStrides[dim] * remainder;
80 }
81
82 return dstLinearIndex;
83}
84
85template <typename ElemType, typename AttrType>
86Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
87 DenseElementsAttr input, AttrType padValue,
88 ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
89 auto inputValues = input.tryGetValues<ElemType>();
90 if (failed(inputValues))
91 return nullptr;
92
93 auto oldShape = input.getType().getShape();
94
95 // Compute the output shape of the new value.
96 auto newShape =
97 llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98 [](std::tuple<int64_t, int64_t, int64_t> pack) {
99 auto [old, low, high] = pack;
100 return old + low + high;
101 });
102
103 int64_t outputSize = computeProduct(newShape);
104
105 // Fully initialize the vector with the padding value.
106 // The non-padded area will then be copied.
107 SmallVector<ElemType> values(outputSize, padValue.getValue());
108
109 // Strides for input and output are used to transform between the indexing
110 // space of the input and output tensors.
111 SmallVector<int64_t> outputStrides = computeStrides(newShape);
112
113 // The contribution of the low padding to the offset in the output tensor.
114 // This is the starting position of the source tensor within the padding
115 // tensor.
116 int64_t startingOffset = linearize(padLow, outputStrides);
117
118 // Copy values from the input tensor to the corresponding sub-region
119 // of the output tensor.
120 for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
121 auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122 values[outputIndex + startingOffset] = inputValue;
123 }
124
125 // Create an attribute for the folded value.
126 auto newType = input.getType().clone(newShape);
127 auto newAttr = DenseElementsAttr::get(newType, values);
128
129 Operation *constantOp =
130 rewriter.getContext()
131 ->getLoadedDialect<TensorDialect>()
132 ->materializeConstant(rewriter, newAttr, newType, loc);
133
134 return constantOp ? constantOp->getResult(0) : nullptr;
135}
136
137struct PadOpToConstant final : public OpRewritePattern<PadOp> {
138
139 PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
140 PatternBenefit benefit = 1)
141 : OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}
142
143 LogicalResult matchAndRewrite(PadOp padTensorOp,
144 PatternRewriter &rewriter) const override {
145 if (padTensorOp.getNofold())
146 return rewriter.notifyMatchFailure(
147 padTensorOp, "refusing to fold nofold pad operation");
148
149 TypedValue<RankedTensorType> input = padTensorOp.getSource();
150 RankedTensorType resultType = padTensorOp.getResult().getType();
151
152 DenseElementsAttr inputAttr = nullptr;
153 if (!matchPattern(input, m_Constant(&inputAttr)))
154 return failure();
155
156 Value paddingValue = padTensorOp.getConstantPaddingValue();
157
158 // Extract the constant value used for padding or bail out.
159 Attribute paddingAttr = nullptr;
160 if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
161 return rewriter.notifyMatchFailure(padTensorOp,
162 "unable to get constant value");
163
164 // Try to extract the constant values of the low and high padding.
165 auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
166 auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());
167
168 // If the padding cannot be extracted, bail out.
169 if (!lowPad || !highPad)
170 return rewriter.notifyMatchFailure(padTensorOp,
171 "unable to extract constant padding");
172
173 // We have a potential candidate, consult the control function to
174 // determine if the op should fold.
175 if (!controlFn(&padTensorOp.getSourceMutable()))
176 return rewriter.notifyMatchFailure(padTensorOp,
177 "not folding due to cost function");
178
179 Location loc = padTensorOp.getLoc();
180
181 // Try constant folding the supported cases of integer and float values.
182 Value newOp =
183 llvm::TypeSwitch<Attribute, Value>(paddingAttr)
184 .Case([&](FloatAttr floatAttr) {
185 return constantFoldPadOp<llvm::APFloat>(
186 rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
187 })
188 .Case([&](IntegerAttr integerAttr) {
189 return constantFoldPadOp<llvm::APInt>(
190 rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
191 })
192 .Default(nullptr);
193
194 if (!newOp)
195 return rewriter.notifyMatchFailure(padTensorOp,
196 "tensor type not supported");
197
198 if (newOp.getType() != resultType)
199 newOp = tensor::CastOp::create(rewriter, loc, resultType, newOp);
200
201 rewriter.replaceOp(padTensorOp, newOp);
202 return success();
203 }
204
205private:
206 ControlFoldFn controlFn;
207};
208
209} // namespace
210
212 RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
213 patterns.add<GenerateToConstant>(patterns.getContext());
214
215 patterns.add<PadOpToConstant>(patterns.getContext(), controlFn);
216}
return success()
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition FoldUtils.cpp:51
MLIRContext * getContext() const
Definition Builders.h:56
An attribute that represents a reference to a dense vector or tensor object.
FailureOr< iterator_range_impl< ElementIterator< T > > > tryGetValues() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
std::function< bool(OpOperand *)> ControlFoldFn
Definition Transforms.h:99
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates `patterns` with patterns that replace tensor ops (such as tensor.generate) with constants when possible.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explanatory: returns the product of all elements of `basis`.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
const FrozenRewritePatternSet & patterns
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t. 'basis'.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.