MLIR  17.0.0git
ConstantFold.cpp
Go to the documentation of this file.
1 //===- ConstantFold.cpp - Implementation of constant folding on Linalg ops ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements constant folding on Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Support/LLVM.h"
20 #include <optional>
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
25 namespace {
26 /// Base class for constant folding linalg.generic ops with N inputs, 1 output,
27 /// and permutation indexing maps.
28 ///
29 /// `ConcreteType` should provide methods with signatures
30 ///
31 /// ```c++
32 /// bool matchIndexingMaps(GenericOp genericOp) const;
33 /// RegionComputationFn getRegionComputeFn(GenericOp) const;
34 /// ```
35 ///
36 /// The latter inspects the region and returns the computation inside as a
37 /// functor. The functor will be invoked with constant elements for all inputs
38 /// and should return the corresponding computed constant element for output.
39 template <typename ConcreteType>
40 class FoldConstantBase : public OpRewritePattern<GenericOp> {
41 public:
42  struct APIntOrFloat {
43  std::optional<APInt> apInt;
44  std::optional<APFloat> apFloat;
45  };
46  struct APIntOrFloatArray {
47  SmallVector<APInt> apInts;
48  SmallVector<APFloat> apFloats;
49  };
50  using RegionComputationFn =
51  std::function<APIntOrFloat(const APIntOrFloatArray &)>;
52 
53  FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
54  PatternBenefit benefit = 1)
55  : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
56 
57  LogicalResult matchAndRewrite(GenericOp genericOp,
58  PatternRewriter &rewriter) const override {
59  // Mixed and buffer sematics aren't supported.
60  if (!genericOp.hasTensorSemantics())
61  return failure();
62 
63  // Only support ops generating one output for now.
64  if (genericOp.getNumDpsInits() != 1)
65  return failure();
66 
67  auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
68  // Require the output types to be static given that we are generating
69  // constants.
70  if (!outputType || !outputType.hasStaticShape())
71  return failure();
72 
73  if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
74  return input.getType().isa<ShapedType>();
75  }))
76  return failure();
77 
78  // Make sure all element types are the same.
79  auto getOperandElementType = [](Value value) {
80  return value.getType().cast<ShapedType>().getElementType();
81  };
82  if (!llvm::all_equal(
83  llvm::map_range(genericOp->getOperands(), getOperandElementType)))
84  return failure();
85 
86  // We can only handle the case where we have int/float elements.
87  auto elementType = outputType.getElementType();
88  if (!elementType.isIntOrFloat())
89  return failure();
90 
91  // Require all indexing maps to be permutations for now. This is common and
92  // it simplifies input/output access greatly: we can do the data shuffling
93  // entirely in the compiler, without needing to turn all indices into
94  // Values, and then do affine apply on them, and then match back the
95  // constant again.
96  if (!llvm::all_of(genericOp.getIndexingMapsArray(),
97  [](AffineMap map) { return map.isPermutation(); }))
98  return failure();
99 
100  for (OpOperand *operand : genericOp.getDpsInitOperands()) {
101  if (genericOp.payloadUsesValueFromOperand(operand))
102  return failure();
103  }
104 
105  // Further check the indexing maps are okay for the ConcreteType.
106  if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
107  return failure();
108 
109  // Defer to the concrete type to check the region and discover the
110  // computation inside.
111  RegionComputationFn computeFn =
112  static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
113  if (!computeFn)
114  return failure();
115 
116  // All inputs should be constants.
117  int numInputs = genericOp.getNumDpsInputs();
118  SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
119  for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
120  if (!matchPattern(en.value()->get(),
121  m_Constant(&inputValues[en.index()])))
122  return failure();
123  }
124 
125  // Identified this as a potential candidate for folding. Now check the
126  // policy to see whether we are allowed to proceed.
127  for (OpOperand *operand : genericOp.getDpsInputOperands()) {
128  if (!controlFn(operand))
129  return failure();
130  }
131 
132  auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
133  SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
134  int64_t numElements = outputType.getNumElements();
135 
136  // Use APInt/APFloat instead of Attribute here for constructing the output.
137  // This helps to avoid blowing up compiler memory usage: Attributes would
138  // unify the following cases but they have lifetime as the MLIRContext.
139  SmallVector<APInt> intOutputValues;
140  SmallVector<APFloat> fpOutputValues;
141  if (elementType.template isa<FloatType>())
142  fpOutputValues.resize(numElements, APFloat(0.f));
143  else
144  intOutputValues.resize(numElements);
145 
146  // Return the constant dim positions from the given permutation map.
147  auto getDimPositions = [](AffineMap map) {
149  dims.reserve(map.getNumResults());
150  for (AffineExpr result : map.getResults()) {
151  dims.push_back(result.cast<AffineDimExpr>().getPosition());
152  }
153  return dims;
154  };
155 
157  for (int i = 0; i < numInputs; ++i)
158  inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
159  auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
160  auto outputShape = outputType.getShape();
161 
162  // Allocate small vectors for index delinearization. Initial values do not
163  // matter here as they will be overwritten later.
164  SmallVector<uint64_t> indices(loopBounds.size(), 0);
165  SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
167  numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
168  SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
169  uint64_t dstLinearIndex = 0;
170 
171  // Allocate spaces for compute function inputs. Initial values do not matter
172  // here as they will be overwritten later.
173  APIntOrFloatArray computeFnInputs;
174 
175  auto inputShapes = llvm::to_vector<4>(
176  llvm::map_range(genericOp.getInputs(), [](Value value) {
177  return value.getType().cast<ShapedType>().getShape();
178  }));
179 
180  // Given a `linearIndex`, remap it to a linear index to access linalg op
181  // inputs/ouputs. This mutates `indices`, `srcIndices`, `dstIndices`,
182  // `srcLinearIndices`, `dstLinearIndex` in place.
183  auto computeRemappedLinearIndex = [&](int linearIndex) {
184  int totalCount = linearIndex;
185  for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
186  indices[dim] = totalCount % loopBounds[dim];
187  totalCount /= loopBounds[dim];
188  }
189 
190  for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
191  for (int i = 0; i < numInputs; ++i)
192  srcIndices[i][dim] = indices[inputDims[i][dim]];
193  dstIndices[dim] = indices[outputDims[dim]];
194  }
195 
196  dstLinearIndex = dstIndices.front();
197  for (int i = 0; i < numInputs; ++i)
198  srcLinearIndices[i] = srcIndices[i].front();
199 
200  for (int dim = 1; dim < outputType.getRank(); ++dim) {
201  dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
202  for (int i = 0; i < numInputs; ++i)
203  srcLinearIndices[i] =
204  srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
205  }
206  };
207 
208  bool isFloat = elementType.isa<FloatType>();
209  if (isFloat) {
211  for (int i = 0; i < numInputs; ++i)
212  inFpRanges.push_back(inputValues[i].getValues<APFloat>());
213 
214  computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
215 
216  // Transpose the input constant. Because we don't know its rank in
217  // advance, we need to loop over the range [0, element count) and
218  // delinearize the index.
219  for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
220  computeRemappedLinearIndex(linearIndex);
221 
222  // Collect constant elements for all inputs at this loop iteration.
223  for (int i = 0; i < numInputs; ++i)
224  computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
225 
226  // Invoke the computation to get the corresponding constant output
227  // element.
228  fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
229  }
230  } else {
232  for (int i = 0; i < numInputs; ++i)
233  inIntRanges.push_back(inputValues[i].getValues<APInt>());
234 
235  computeFnInputs.apInts.resize(numInputs);
236 
237  // Transpose the input constant. Because we don't know its rank in
238  // advance, we need to loop over the range [0, element count) and
239  // delinearize the index.
240  for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
241  computeRemappedLinearIndex(linearIndex);
242 
243  // Collect constant elements for all inputs at this loop iteration.
244  for (int i = 0; i < numInputs; ++i)
245  computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
246 
247  // Invoke the computation to get the corresponding constant output
248  // element.
249  intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
250  }
251  }
252 
253  DenseElementsAttr outputAttr =
254  isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
255  : DenseElementsAttr::get(outputType, intOutputValues);
256 
257  rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
258  return success();
259  }
260 
261 private:
262  ControlFusionFn controlFn;
263 };
264 
265 // Folds linalg.generic ops that are actually transposes on constant values.
266 struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
267  using FoldConstantBase::FoldConstantBase;
268 
269  bool matchIndexingMaps(GenericOp genericOp) const {
270  // We should have one input and one output.
271  return genericOp.getIndexingMapsArray().size() == 2;
272  }
273 
274  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
275  // Make sure the region only contains a yield op.
276  Block &body = genericOp.getRegion().front();
277  if (!llvm::hasSingleElement(body))
278  return nullptr;
279  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
280  if (!yieldOp)
281  return nullptr;
282 
283  // The yield op should return the block argument corresponds to the input.
284  for (Value yieldVal : yieldOp.getValues()) {
285  auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
286  if (!yieldArg || yieldArg.getOwner() != &body)
287  return nullptr;
288  if (yieldArg.getArgNumber() != 0)
289  return nullptr;
290  }
291 
292  // No computation; just return the orginal value.
293  return [](const APIntOrFloatArray &inputs) {
294  if (inputs.apFloats.empty())
295  return APIntOrFloat{inputs.apInts.front(), std::nullopt};
296  return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
297  };
298  }
299 
300  ControlFusionFn controlFn;
301 };
302 } // namespace
303 
305  RewritePatternSet &patterns, const ControlFusionFn &controlFn) {
306  MLIRContext *context = patterns.getContext();
307  patterns.insert<FoldConstantTranspose>(context, controlFn);
308 }
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
unsigned getPosition() const
Definition: AffineExpr.cpp:325
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:43
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:332
unsigned getNumResults() const
Definition: AffineMap.cpp:327
This class represents an argument of a Block.
Definition: Value.h:304
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
Operation & front()
Definition: Block.h:142
An attribute that represents a reference to a dense vector or tensor object.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class represents an operand of an operation.
Definition: Value.h:255
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:621
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
Definition: Transforms.h:79
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:322
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:248
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357