26 #define GEN_PASS_DEF_LINALGINLINESCALAROPERANDS
27 #include "mlir/Dialect/Linalg/Passes.h.inc"
38 if (!genericOp.hasTensorSemantics())
44 for (
OpOperand *opOperand : genericOp.getDpsInputOperands()) {
45 AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
46 if (genericOp.isDpsInput(opOperand) && map.
isConstant()) {
47 scalarOperands.emplace_back(opOperand->getOperandNumber());
49 newIndexingMaps.emplace_back(map);
50 newOperands.emplace_back(opOperand->get());
54 if (scalarOperands.empty())
57 for (
OpOperand &opOperand : genericOp.getDpsInitsMutable())
58 newIndexingMaps.emplace_back(
59 genericOp.getMatchingIndexingMap(&opOperand));
63 auto newOp = rewriter.
create<GenericOp>(
64 loc, genericOp->getResultTypes(), newOperands, outputOperands,
65 newIndexingMaps, genericOp.getIteratorTypesArray());
67 newOp.getRegion().begin());
69 Block *body = newOp.getBody();
70 PatternRewriter::InsertionGuard guard(rewriter);
73 for (
auto idx : llvm::reverse(scalarOperands)) {
74 OpOperand *opOperand = genericOp.getDpsInputOperand(idx);
75 AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
78 for (
auto idx : indices)
79 indicesValues.emplace_back(
80 rewriter.
create<arith::ConstantIndexOp>(loc, idx));
81 Value extractedValue = rewriter.
create<tensor::ExtractOp>(
82 loc, opOperand->
get(), indicesValues);
87 rewriter.
replaceOp(genericOp, newOp->getResults());
98 patterns.
add<InlineScalarOperands>(context);
103 struct LinalgInlineScalarOperandsPass
104 :
public impl::LinalgInlineScalarOperandsBase<
105 LinalgInlineScalarOperandsPass> {
106 void runOnOperation()
override {
117 return std::make_unique<LinalgInlineScalarOperandsPass>();
static MLIRContext * getContext(OpFoldResult val)
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
bool isConstant() const
Returns true if this affine map has only constant results.
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns)
Patterns that are used to inline constant operands into linalg generic ops.
This header declares functions that assist transformations in the MemRef dialect.
std::unique_ptr< Pass > createLinalgInlineScalarOperandsPass()
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...