26#define GEN_PASS_DEF_LINALGINLINESCALAROPERANDSPASS
27#include "mlir/Dialect/Linalg/Passes.h.inc"
35 using OpRewritePattern<GenericOp>::OpRewritePattern;
36 LogicalResult matchAndRewrite(GenericOp genericOp,
37 PatternRewriter &rewriter)
const override {
38 if (!genericOp.hasPureTensorSemantics())
41 SmallVector<size_t> scalarOperands;
42 SmallVector<AffineMap> newIndexingMaps;
43 SmallVector<Value> newOperands;
44 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
45 AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
46 if (genericOp.isDpsInput(opOperand) && map.
isConstant()) {
47 scalarOperands.emplace_back(opOperand->getOperandNumber());
49 newIndexingMaps.emplace_back(map);
50 newOperands.emplace_back(opOperand->get());
54 if (scalarOperands.empty())
57 for (OpOperand &opOperand : genericOp.getDpsInitsMutable())
58 newIndexingMaps.emplace_back(
59 genericOp.getMatchingIndexingMap(&opOperand));
61 Location loc = genericOp->getLoc();
62 SmallVector<Value> outputOperands = genericOp.getOutputs();
63 auto newOp = GenericOp::create(rewriter, loc, genericOp->getResultTypes(),
64 newOperands, outputOperands, newIndexingMaps,
65 genericOp.getIteratorTypesArray());
67 newOp.getRegion().begin());
69 Block *body = newOp.getBody();
70 PatternRewriter::InsertionGuard guard(rewriter);
73 for (
auto idx : llvm::reverse(scalarOperands)) {
74 OpOperand *opOperand = genericOp.getDpsInputOperand(idx);
75 AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
77 SmallVector<Value> indicesValues;
79 indicesValues.emplace_back(
81 Value scalarValue = opOperand->
get();
82 if (isa<RankedTensorType>(scalarValue.
getType())) {
83 scalarValue = tensor::ExtractOp::create(rewriter, loc, scalarValue,
90 rewriter.
replaceOp(genericOp, newOp->getResults());
100 auto *context =
patterns.getContext();
101 patterns.add<InlineScalarOperands>(context);
106struct LinalgInlineScalarOperandsPass
108 LinalgInlineScalarOperandsPass> {
110 LinalgInlineScalarOperandsPass>::LinalgInlineScalarOperandsPassBase;
111 void runOnOperation()
override {
bool isConstant() const
Returns true if this affine map has only constant results.
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
BlockArgument getArgument(unsigned i)
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Operation is the basic unit of execution within MLIR.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to use the other value instead.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns)
Patterns that are used to inline constant operands into linalg generic ops.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.