MLIR  22.0.0git
ConstantFold.cpp
Go to the documentation of this file.
1 //===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements constant folding on Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Support/LLVM.h"
18 #include <optional>
19 
20 using namespace mlir;
21 using namespace mlir::linalg;
22 
23 namespace {
24 /// Base class for constant folding linalg structured ops with N inputs, 1
25 /// output, and permutation indexing maps.
26 ///
27 /// `ConcreteType` should provide methods with signatures
28 ///
29 /// ```c++
30 /// bool matchIndexingMaps(LinalgOp linalgOp) const;
31 /// RegionComputationFn getRegionComputeFn(LinalgOp) const;
32 /// ```
33 ///
34 /// The latter inspects the region and returns the computation inside as a
35 /// functor. The functor will be invoked with constant elements for all inputs
36 /// and should return the corresponding computed constant element for output.
37 template <typename ConcreteType>
38 class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
39 public:
40  struct APIntOrFloat {
41  std::optional<APInt> apInt;
42  std::optional<APFloat> apFloat;
43  };
44  struct APIntOrFloatArray {
45  SmallVector<APInt> apInts;
46  SmallVector<APFloat> apFloats;
47  };
48  using RegionComputationFn =
49  std::function<APIntOrFloat(const APIntOrFloatArray &)>;
50 
51  FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
52  PatternBenefit benefit = 1)
53  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
54  controlFn(controlFn) {}
55 
56  LogicalResult matchAndRewrite(LinalgOp linalgOp,
57  PatternRewriter &rewriter) const override {
58  // Mixed and buffer sematics aren't supported.
59  if (!linalgOp.hasPureTensorSemantics())
60  return failure();
61 
62  // Only support ops generating one output for now.
63  if (linalgOp.getNumDpsInits() != 1)
64  return failure();
65 
66  auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
67  // Require the output types to be static given that we are generating
68  // constants.
69  if (!outputType || !outputType.hasStaticShape())
70  return failure();
71 
72  if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
73  return isa<ShapedType>(input.getType());
74  }))
75  return failure();
76 
77  // Make sure all element types are the same.
78  auto getOperandElementType = [](Value value) {
79  return cast<ShapedType>(value.getType()).getElementType();
80  };
81  if (!llvm::all_equal(
82  llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
83  return failure();
84 
85  // We can only handle the case where we have int/float elements.
86  auto elementType = outputType.getElementType();
87  if (!elementType.isIntOrFloat())
88  return failure();
89 
90  // Require all indexing maps to be permutations for now. This is common and
91  // it simplifies input/output access greatly: we can do the data shuffling
92  // entirely in the compiler, without needing to turn all indices into
93  // Values, and then do affine apply on them, and then match back the
94  // constant again.
95  if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
96  [](AffineMap map) { return map.isPermutation(); }))
97  return failure();
98 
99  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
100  if (linalgOp.payloadUsesValueFromOperand(&operand))
101  return failure();
102  }
103 
104  // Further check the indexing maps are okay for the ConcreteType.
105  if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
106  return failure();
107 
108  // Defer to the concrete type to check the region and discover the
109  // computation inside.
110  RegionComputationFn computeFn =
111  static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
112  if (!computeFn)
113  return failure();
114 
115  // All inputs should be constants.
116  int numInputs = linalgOp.getNumDpsInputs();
117  SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
118  for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
119  if (!matchPattern(en.value()->get(),
120  m_Constant(&inputValues[en.index()])))
121  return failure();
122  }
123 
124  // Identified this as a potential candidate for folding. Now check the
125  // policy to see whether we are allowed to proceed.
126  for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
127  if (!controlFn(operand))
128  return failure();
129  }
130 
131  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
132  int64_t numElements = outputType.getNumElements();
133 
134  // Use APInt/APFloat instead of Attribute here for constructing the output.
135  // This helps to avoid blowing up compiler memory usage: Attributes would
136  // unify the following cases but they have lifetime as the MLIRContext.
137  SmallVector<APInt> intOutputValues;
138  SmallVector<APFloat> fpOutputValues;
139  if (isa<FloatType>(elementType))
140  fpOutputValues.resize(numElements, APFloat(0.f));
141  else
142  intOutputValues.resize(numElements);
143 
144  // Return the constant dim positions from the given permutation map.
145  auto getDimPositions = [](AffineMap map) {
147  dims.reserve(map.getNumResults());
148  for (AffineExpr result : map.getResults()) {
149  dims.push_back(cast<AffineDimExpr>(result).getPosition());
150  }
151  return dims;
152  };
153 
155  for (int i = 0; i < numInputs; ++i)
156  inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
157  auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
158  auto outputShape = outputType.getShape();
159 
160  // Allocate small vectors for index delinearization. Initial values do not
161  // matter here as they will be overwritten later.
162  SmallVector<uint64_t> indices(loopBounds.size(), 0);
163  SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
165  numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
166  SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
167  uint64_t dstLinearIndex = 0;
168 
169  // Allocate spaces for compute function inputs. Initial values do not matter
170  // here as they will be overwritten later.
171  APIntOrFloatArray computeFnInputs;
172 
173  auto inputShapes = llvm::to_vector<4>(
174  llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
175  return cast<ShapedType>(value.getType()).getShape();
176  }));
177 
178  // Given a `linearIndex`, remap it to a linear index to access linalg op
179  // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
180  // `srcLinearIndices`, `dstLinearIndex` in place.
181  auto computeRemappedLinearIndex = [&](int linearIndex) {
182  int totalCount = linearIndex;
183  for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
184  indices[dim] = totalCount % loopBounds[dim];
185  totalCount /= loopBounds[dim];
186  }
187 
188  for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
189  for (int i = 0; i < numInputs; ++i)
190  srcIndices[i][dim] = indices[inputDims[i][dim]];
191  dstIndices[dim] = indices[outputDims[dim]];
192  }
193 
194  dstLinearIndex = dstIndices.front();
195  for (int i = 0; i < numInputs; ++i)
196  srcLinearIndices[i] = srcIndices[i].front();
197 
198  for (int dim = 1; dim < outputType.getRank(); ++dim) {
199  dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
200  for (int i = 0; i < numInputs; ++i)
201  srcLinearIndices[i] =
202  srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
203  }
204  };
205 
206  bool isFloat = isa<FloatType>(elementType);
207  if (isFloat) {
209  for (int i = 0; i < numInputs; ++i)
210  inFpRanges.push_back(inputValues[i].getValues<APFloat>());
211 
212  computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
213 
214  // Transpose the input constant. Because we don't know its rank in
215  // advance, we need to loop over the range [0, element count) and
216  // delinearize the index.
217  for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
218  computeRemappedLinearIndex(linearIndex);
219 
220  // Collect constant elements for all inputs at this loop iteration.
221  for (int i = 0; i < numInputs; ++i)
222  computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
223 
224  // Invoke the computation to get the corresponding constant output
225  // element.
226  fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
227  }
228  } else {
230  for (int i = 0; i < numInputs; ++i)
231  inIntRanges.push_back(inputValues[i].getValues<APInt>());
232 
233  computeFnInputs.apInts.resize(numInputs);
234 
235  // Transpose the input constant. Because we don't know its rank in
236  // advance, we need to loop over the range [0, element count) and
237  // delinearize the index.
238  for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
239  computeRemappedLinearIndex(linearIndex);
240 
241  // Collect constant elements for all inputs at this loop iteration.
242  for (int i = 0; i < numInputs; ++i)
243  computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
244 
245  // Invoke the computation to get the corresponding constant output
246  // element.
247  intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
248  }
249  }
250 
251  DenseElementsAttr outputAttr =
252  isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
253  : DenseElementsAttr::get(outputType, intOutputValues);
254 
255  rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
256  return success();
257  }
258 
259 private:
260  ControlFusionFn controlFn;
261 };
262 
263 // Folds linalg.transpose (and linalg.generic ops that are actually transposes)
264 // on constant values.
265 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
266 
267  using FoldConstantBase::FoldConstantBase;
268 
269  bool matchIndexingMaps(LinalgOp linalgOp) const {
270  // We should have one input and one output.
271  return linalgOp.getIndexingMapsArray().size() == 2;
272  }
273 
274  RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
275  // Make sure the region only contains a yield op.
276  Block &body = linalgOp->getRegion(0).front();
277  if (!llvm::hasSingleElement(body))
278  return nullptr;
279  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
280  if (!yieldOp)
281  return nullptr;
282 
283  // The yield op should return the block argument corresponds to the input.
284  for (Value yieldVal : yieldOp.getValues()) {
285  auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
286  if (!yieldArg || yieldArg.getOwner() != &body)
287  return nullptr;
288  if (yieldArg.getArgNumber() != 0)
289  return nullptr;
290  }
291 
292  // No computation; just return the orginal value.
293  return [](const APIntOrFloatArray &inputs) {
294  if (inputs.apFloats.empty())
295  return APIntOrFloat{inputs.apInts.front(), std::nullopt};
296  return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
297  };
298  }
299 
300  ControlFusionFn controlFn;
301 };
302 } // namespace
303 
305  RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
306  MLIRContext *context = patterns.getContext();
307  patterns.insert<FoldConstantTranspose>(context, controlFn);
308 }
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
Operation & front()
Definition: Block.h:153
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents an operand of an operation.
Definition: Value.h:257
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
Definition: Transforms.h:1902
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330