18#include "llvm/ADT/SmallVectorExtras.h"
38template <
typename ConcreteType>
42 std::optional<APInt> apInt;
43 std::optional<APFloat> apFloat;
45 struct APIntOrFloatArray {
46 SmallVector<APInt> apInts;
47 SmallVector<APFloat> apFloats;
49 using RegionComputationFn =
50 std::function<APIntOrFloat(
const APIntOrFloatArray &)>;
52 FoldConstantBase(MLIRContext *context,
const ControlFusionFn &controlFn,
53 PatternBenefit benefit = 1)
54 : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
55 controlFn(controlFn) {}
57 LogicalResult matchAndRewrite(LinalgOp linalgOp,
58 PatternRewriter &rewriter)
const override {
60 if (!linalgOp.hasPureTensorSemantics())
64 if (linalgOp.getNumDpsInits() != 1)
67 auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
70 if (!outputType || !outputType.hasStaticShape())
73 if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
74 return isa<ShapedType>(input.getType());
79 auto getOperandElementType = [](Value value) {
80 return cast<ShapedType>(value.getType()).getElementType();
83 llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
87 auto elementType = outputType.getElementType();
88 if (!elementType.isIntOrFloat())
96 if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
97 [](AffineMap map) { return map.isPermutation(); }))
100 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
101 if (linalgOp.payloadUsesValueFromOperand(&operand))
106 if (!
static_cast<const ConcreteType *
>(
this)->matchIndexingMaps(linalgOp))
111 RegionComputationFn computeFn =
112 static_cast<const ConcreteType *
>(
this)->getRegionComputeFn(linalgOp);
117 int numInputs = linalgOp.getNumDpsInputs();
118 SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
119 for (
const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
127 for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
128 if (!controlFn(operand))
132 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
133 int64_t numElements = outputType.getNumElements();
138 SmallVector<APInt> intOutputValues;
139 SmallVector<APFloat> fpOutputValues;
140 if (isa<FloatType>(elementType))
141 fpOutputValues.resize(numElements, APFloat(0.f));
143 intOutputValues.resize(numElements);
146 auto getDimPositions = [](AffineMap map) {
147 SmallVector<unsigned> dims;
150 dims.push_back(cast<AffineDimExpr>(
result).getPosition());
155 SmallVector<SmallVector<unsigned>> inputDims;
156 for (
int i = 0; i < numInputs; ++i)
157 inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
158 auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
159 auto outputShape = outputType.getShape();
163 SmallVector<uint64_t>
indices(loopBounds.size(), 0);
164 SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
165 SmallVector<SmallVector<uint64_t>> srcIndices(
166 numInputs, SmallVector<uint64_t>(loopBounds.size(), 0));
167 SmallVector<uint64_t> srcLinearIndices(numInputs, 0);
168 uint64_t dstLinearIndex = 0;
172 APIntOrFloatArray computeFnInputs;
175 llvm::map_to_vector<4>(linalgOp.getDpsInputs(), [](Value value) {
176 return cast<ShapedType>(value.getType()).getShape();
182 auto computeRemappedLinearIndex = [&](
int linearIndex) {
183 int totalCount = linearIndex;
184 for (
int dim = loopBounds.size() - 1; dim >= 0; --dim) {
185 indices[dim] = totalCount % loopBounds[dim];
186 totalCount /= loopBounds[dim];
189 for (
int dim = loopBounds.size() - 1; dim >= 0; --dim) {
190 for (
int i = 0; i < numInputs; ++i)
191 srcIndices[i][dim] =
indices[inputDims[i][dim]];
192 dstIndices[dim] =
indices[outputDims[dim]];
195 dstLinearIndex = dstIndices.front();
196 for (
int i = 0; i < numInputs; ++i)
197 srcLinearIndices[i] = srcIndices[i].front();
199 for (
int dim = 1; dim < outputType.getRank(); ++dim) {
200 dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
201 for (
int i = 0; i < numInputs; ++i)
202 srcLinearIndices[i] =
203 srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim];
207 bool isFloat = isa<FloatType>(elementType);
209 SmallVector<DenseElementsAttr::iterator_range<APFloat>> inFpRanges;
210 for (
int i = 0; i < numInputs; ++i)
211 inFpRanges.push_back(inputValues[i].getValues<APFloat>());
213 computeFnInputs.apFloats.resize(numInputs, APFloat(0.f));
218 for (
int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
219 computeRemappedLinearIndex(linearIndex);
222 for (
int i = 0; i < numInputs; ++i)
223 computeFnInputs.apFloats[i] = inFpRanges[i][srcLinearIndices[i]];
227 fpOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apFloat;
230 SmallVector<DenseElementsAttr::iterator_range<APInt>> inIntRanges;
231 for (
int i = 0; i < numInputs; ++i)
232 inIntRanges.push_back(inputValues[i].getValues<APInt>());
234 computeFnInputs.apInts.resize(numInputs);
239 for (
int linearIndex = 0; linearIndex < numElements; ++linearIndex) {
240 computeRemappedLinearIndex(linearIndex);
243 for (
int i = 0; i < numInputs; ++i)
244 computeFnInputs.apInts[i] = inIntRanges[i][srcLinearIndices[i]];
248 intOutputValues[dstLinearIndex] = *computeFn(computeFnInputs).apInt;
252 DenseElementsAttr outputAttr =
254 : DenseElementsAttr::
get(outputType, intOutputValues);
266struct FoldConstantTranspose :
public FoldConstantBase<FoldConstantTranspose> {
268 using FoldConstantBase::FoldConstantBase;
270 bool matchIndexingMaps(LinalgOp linalgOp)
const {
272 return linalgOp.getIndexingMapsArray().size() == 2;
275 RegionComputationFn getRegionComputeFn(LinalgOp linalgOp)
const {
277 Block &body = linalgOp->getRegion(0).front();
278 if (!llvm::hasSingleElement(body))
280 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
285 for (Value yieldVal : yieldOp.getValues()) {
286 auto yieldArg = dyn_cast<BlockArgument>(yieldVal);
287 if (!yieldArg || yieldArg.getOwner() != &body)
289 if (yieldArg.getArgNumber() != 0)
294 return [](
const APIntOrFloatArray &inputs) {
295 if (inputs.apFloats.empty())
296 return APIntOrFloat{inputs.apInts.front(), std::nullopt};
297 return APIntOrFloat{std::nullopt, inputs.apFloats.front()};
308 patterns.insert<FoldConstantTranspose>(context, controlFn);
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
Operation * getTerminator()
Get the terminator operation of this block.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
MLIRContext is the top-level object for a collection of MLIR operations.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
std::function< bool(OpOperand *fusedOperand)> ControlFusionFn
Function type which is used to control when to stop fusion.
void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn)
Patterns to constant fold Linalg operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...