MLIR 22.0.0git
ConstantFold.cpp
Go to the documentation of this file.
1//===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements constant folding on Linalg operations.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Matchers.h"
17#include "mlir/Support/LLVM.h"
18#include <optional>
19
20using namespace mlir;
21using namespace mlir::linalg;
22
23namespace {
24/// Base class for constant folding linalg structured ops with N inputs, 1
25/// output, and permutation indexing maps.
26///
27/// `ConcreteType` should provide methods with signatures
28///
29/// ```c++
30/// bool matchIndexingMaps(LinalgOp linalgOp) const;
31/// RegionComputationFn getRegionComputeFn(LinalgOp) const;
32/// ```
33///
34/// The latter inspects the region and returns the computation inside as a
35/// functor. The functor will be invoked with constant elements for all inputs
36/// and should return the corresponding computed constant element for output.
37template <typename ConcreteType>
38class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
39public:
40 struct APIntOrFloat {
41 std::optional<APInt> apInt;
42 std::optional<APFloat> apFloat;
43 };
44 struct APIntOrFloatArray {
45 SmallVector<APInt> apInts;
46 SmallVector<APFloat> apFloats;
47 };
48 using RegionComputationFn =
49 std::function<APIntOrFloat(const APIntOrFloatArray &)>;
50
51 FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
52 PatternBenefit benefit = 1)
53 : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
54 controlFn(controlFn) {}
55
56 LogicalResult matchAndRewrite(LinalgOp linalgOp,
57 PatternRewriter &rewriter) const override {
58 // Mixed and buffer sematics aren't supported.
59 if (!linalgOp.hasPureTensorSemantics())
60 return failure();
61
62 // Only support ops generating one output for now.
63 if (linalgOp.getNumDpsInits() != 1)
64 return failure();
65
66 auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
67 // Require the output types to be static given that we are generating
68 // constants.
69 if (!outputType || !outputType.hasStaticShape())
70 return failure();
71
72 if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
73 return isa<ShapedType>(input.getType());
74 }))
75 return failure();
76
77 // Make sure all element types are the same.
78 auto getOperandElementType = [](Value value) {
79 return cast<ShapedType>(value.getType()).getElementType();
80 };
81 if (!llvm::all_equal(
82 llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
83 return failure();
84
85 // We can only handle the case where we have int/float elements.
86 auto elementType = outputType.getElementType();
87 if (!elementType.isIntOrFloat())
88 return failure();
89
90 // Require all indexing maps to be permutations for now. This is common and
91 // it simplifies input/output access greatly: we can do the data shuffling
92 // entirely in the compiler, without needing to turn all indices into
93 // Values, and then do affine apply on them, and then match back the
94 // constant again.
95 if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
96 [](AffineMap map) { return map.isPermutation(); }))
97 return failure();
98
99 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
100 if (linalgOp.payloadUsesValueFromOperand(&operand))
101 return failure();
102 }
103
104 // Further check the indexing maps are okay for the ConcreteType.
105 if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
106 return failure();
107
108 // Defer to the concrete type to check the region and discover the
109 // computation inside.
110 RegionComputationFn computeFn =
111 static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
112 if (!computeFn)
113 return failure();
114
115 // All inputs should be constants.
116 int numInputs = linalgOp.getNumDpsInputs();
117 SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
118 for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
119 if (!matchPattern(en.value()->get(),
120 m_Constant(&inputValues[en.index()])))
121 return failure();
122 }
123
124 // Identified this as a potential candidate for folding. Now check the
125 // policy to see whether we are allowed to proceed.
126 for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
127 if (!controlFn(operand))
128 return failure();
129 }
130
131 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
132 int64_t numElements = outputType.getNumElements();
133
134 // Use APInt/APFloat instead of Attribute here for constructing the output.
135 // This helps to avoid blowing up compiler memory usage: Attributes would
136 // unify the following cases but they have lifetime as the MLIRContext.
137 SmallVector<APInt> intOutputValues;
138 SmallVector<APFloat> fpOutputValues;
139 if (isa<FloatType>(elementType))
140 fpOutputValues.resize(numElements, APFloat(0.f));
141 else
142 intOutputValues.resize(numElements);
143
144 // Return the constant dim positions from the given permutation map.
145 auto getDimPositions = [](AffineMap map) {
146 SmallVector<unsigned> dims;
147 dims.reserve(map.getNumResults());
148 for (AffineExpr result : map.getResults()) {
149 dims.push_back(cast<AffineDimExpr>(result).getPosition());
150 }
151 return dims;
152 };
153
154 SmallVector<SmallVector<unsigned>> inputDims;
155 for (int i = 0; i < numInputs; ++i)
156 inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
157 auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
158 auto outputShape = outputType.getShape();
159
160 // Allocate small vectors for index delinearization. Initial values do not
161 // matter here as they will be overwritten later.
162 SmallVector<uint64_t> indices(loopBounds.size(), 0);
163 SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
164 SmallVector<SmallVector<uint64_t>> srcIndices(
165 numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
166 SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
167 uint64_t dstLinearIndex = 0;
168
169 // Allocate spaces for compute function inputs. Initial values do not matter
170 // here as they will be overwritten later.
171 APIntOrFloatArray computeFnInputs;
172
173 auto inputShapes = llvm::to_vector<4>(
174 llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
175 return cast<ShapedType>(value.getType()).getShape();
176 }));
177
178 // Given a `linearIndex`, remap it to a linear index to access linalg op
179 // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
180 // `srcLinearIndices`, `dstLinearIndex` in place.
181 auto computeRemappedLinearIndex = [&](int linearIndex) {
182 int totalCount = linearIndex;
183 for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
184 indices[dim] = totalCount % loopBounds[dim];
185 totalCount /= loopBounds[dim];
186 }
187
188 for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
189 for (int i = 0; i < numInputs; ++i)
190 srcIndices[i][dim] = indices[inputDims[i][dim]];
191 dstIndices[dim] = indices[outputDims[dim]];
192 }
193
194 dstLinearIndex = dstIndices.front();
195 for (int i = 0; i < numInputs; ++i)
196 srcLinearIndices[i] = srcIndices[i].front();
197
198 for (int dim = 1; dim < outputType.getRank(); ++dim) {
199 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
200 for (int i = 0; i < numInputs; ++i)
201 srcLinearIndices[i] =
202 srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
203 }
204 };
205
206 bool isFloat = isa<FloatType>(elementType);
207 if (isFloat) {
208 SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
209 for (int i = 0; i < numInputs; ++i)
210 inFpRanges.push_back(inputValues[i].getValues<APFloat>());
211
212 computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
213
214 // Transpose the input constant. Because we don't know its rank in
215 // advance, we need to loop over the range [0, element count) and
216 // delinearize the index.
217 for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
218 computeRemappedLinearIndex(linearIndex);
219
220 // Collect constant elements for all inputs at this loop iteration.
221 for (int i = 0; i < numInputs; ++i)
222 computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
223
224 // Invoke the computation to get the corresponding constant output
225 // element.
226 fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
227 }
228 } else {
229 SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
230 for (int i = 0; i < numInputs; ++i)
231 inIntRanges.push_back(inputValues[i].getValues<APInt>());
232
233 computeFnInputs.apInts.resize(numInputs);
234
235 // Transpose the input constant. Because we don't know its rank in
236 // advance, we need to loop over the range [0, element count) and
237 // delinearize the index.
238 for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
239 computeRemappedLinearIndex(linearIndex);
240
241 // Collect constant elements for all inputs at this loop iteration.
242 for (int i = 0; i < numInputs; ++i)
243 computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
244
245 // Invoke the computation to get the corresponding constant output
246 // element.
247 intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
248 }
249 }
250
251 DenseElementsAttr outputAttr =
252 isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
253 : DenseElementsAttr::get(outputType, intOutputValues);
254
255 rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
256 return success();
257 }
258
259private:
260 ControlFusionFn controlFn;
261};
262
263// Folds linalg.transpose (and linalg.generic ops that are actually transposes)
264// on constant values.
265struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
266
267 using FoldConstantBase::FoldConstantBase;
268
269 bool matchIndexingMaps(LinalgOp linalgOp) const {
270 // We should have one input and one output.
271 return linalgOp.getIndexingMapsArray().size() == 2;
272 }
273
274 RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
275 // Make sure the region only contains a yield op.
276 Block &body = linalgOp->getRegion(0).front();
277 if (!llvm::hasSingleElement(body))
278 return nullptr;
279 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
280 if (!yieldOp)
281 return nullptr;
282
283 // The yield op should return the block argument corresponds to the input.
284 for (Value yieldVal : yieldOp.getValues()) {
285 auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
286 if (!yieldArg || yieldArg.getOwner() != &body)
287 return nullptr;
288 if (yieldArg.getArgNumber() != 0)
289 return nullptr;
290 }
291
292 // No computation; just return the orginal value.
293 return [](const APIntOrFloatArray &inputs) {
294 if (inputs.apFloats.empty())
295 return APIntOrFloat{inputs.apInts.front(), std::nullopt};
296 return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
297 };
298 }
299
300 ControlFusionFn controlFn;
301};
302} // namespace
303
305 RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
306 MLIRContext *context = patterns.getContext();
307 patterns.insert<FoldConstantTranspose>(context, controlFn);
308}
return success()
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...