MLIR 22.0.0git
InlineScalarOperands.cpp
Go to the documentation of this file.
1//===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns/pass to inline scalar operands into a generic
10// operation. A scalar operand is an input operand whose indexing map has only
11// constant results.
12//
13//===----------------------------------------------------------------------===//
14
16
21#include "mlir/IR/AffineExpr.h"
22#include "mlir/IR/AffineMap.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_LINALGINLINESCALAROPERANDSPASS
27#include "mlir/Dialect/Linalg/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace mlir::linalg;
32
33namespace {
34struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
35 using OpRewritePattern<GenericOp>::OpRewritePattern;
36 LogicalResult matchAndRewrite(GenericOp genericOp,
37 PatternRewriter &rewriter) const override {
38 if (!genericOp.hasPureTensorSemantics())
39 return failure();
40
41 SmallVector<size_t> scalarOperands;
42 SmallVector<AffineMap> newIndexingMaps;
43 SmallVector<Value> newOperands;
44 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
45 AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
46 if (genericOp.isDpsInput(opOperand) && map.isConstant()) {
47 scalarOperands.emplace_back(opOperand->getOperandNumber());
48 } else {
49 newIndexingMaps.emplace_back(map);
50 newOperands.emplace_back(opOperand->get());
51 }
52 }
53
54 if (scalarOperands.empty())
55 return failure();
56
57 for (OpOperand &opOperand : genericOp.getDpsInitsMutable())
58 newIndexingMaps.emplace_back(
59 genericOp.getMatchingIndexingMap(&opOperand));
60
61 Location loc = genericOp->getLoc();
62 SmallVector<Value> outputOperands = genericOp.getOutputs();
63 auto newOp = GenericOp::create(rewriter, loc, genericOp->getResultTypes(),
64 newOperands, outputOperands, newIndexingMaps,
65 genericOp.getIteratorTypesArray());
66 rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
67 newOp.getRegion().begin());
68
69 Block *body = newOp.getBody();
70 PatternRewriter::InsertionGuard guard(rewriter);
71 rewriter.setInsertionPointToStart(body);
72
73 for (auto idx : llvm::reverse(scalarOperands)) {
74 OpOperand *opOperand = genericOp.getDpsInputOperand(idx);
75 AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
76 SmallVector<int64_t> indices = map.getConstantResults();
77 SmallVector<Value> indicesValues;
78 for (auto idx : indices)
79 indicesValues.emplace_back(
80 arith::ConstantIndexOp::create(rewriter, loc, idx));
81 Value scalarValue = opOperand->get();
82 if (isa<RankedTensorType>(scalarValue.getType())) {
83 scalarValue = tensor::ExtractOp::create(rewriter, loc, scalarValue,
84 indicesValues);
85 }
86 body->getArgument(idx).replaceAllUsesWith(scalarValue);
87 body->eraseArgument(idx);
88 }
89
90 rewriter.replaceOp(genericOp, newOp->getResults());
91 return success();
92 }
93};
94} // namespace
95
96/// Patterns that are used to inline constant operands into linalg generic
97/// ops.
/// Adds the InlineScalarOperands pattern to `patterns`. That pattern removes
/// each `linalg.generic` input whose indexing map has only constant results
/// and substitutes the operand (or the element it addresses) in the op body.
100 auto *context = patterns.getContext();
101 patterns.add<InlineScalarOperands>(context);
102}
103
104namespace {
105/// Pass that inlines scalar operands (inputs whose indexing map has only
/// constant results) into linalg generic ops.
106struct LinalgInlineScalarOperandsPass
108 LinalgInlineScalarOperandsPass> {
110 LinalgInlineScalarOperandsPass>::LinalgInlineScalarOperandsPassBase;
111 void runOnOperation() override {
112 Operation *op = getOperation();
113 MLIRContext &ctx = getContext();
// The LogicalResult is deliberately discarded, so a failure reported by the
// greedy driver does not fail this pass.
116 (void)applyPatternsGreedily(op, std::move(patterns));
117 }
118};
119} // namespace
return success()
b getContext())
bool isConstant() const
Returns true if this affine map has only constant results.
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
BlockArgument getArgument(unsigned i)
Definition Block.h:129
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition Block.cpp:193
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition Builders.cpp:589
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Type getType() const
Return the type of this value.
Definition Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition Value.h:149
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns)
Patterns that are used to inline constant operands into linalg generic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...