MLIR  19.0.0git
InlineScalarOperands.cpp
Go to the documentation of this file.
1 //===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns/pass to inline scalar operands into a generic
10 // operation. A scalar operand is an operand whose indexing map has a constant
11 // rhs.
12 //
13 //===----------------------------------------------------------------------===//
14 
16 
#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_LINALGINLINESCALAROPERANDSPASS
27 #include "mlir/Dialect/Linalg/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
36  LogicalResult matchAndRewrite(GenericOp genericOp,
37  PatternRewriter &rewriter) const override {
38  if (!genericOp.hasPureTensorSemantics())
39  return failure();
40 
41  SmallVector<size_t> scalarOperands;
42  SmallVector<AffineMap> newIndexingMaps;
43  SmallVector<Value> newOperands;
44  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
45  AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
46  if (genericOp.isDpsInput(opOperand) && map.isConstant()) {
47  scalarOperands.emplace_back(opOperand->getOperandNumber());
48  } else {
49  newIndexingMaps.emplace_back(map);
50  newOperands.emplace_back(opOperand->get());
51  }
52  }
53 
54  if (scalarOperands.empty())
55  return failure();
56 
57  for (OpOperand &opOperand : genericOp.getDpsInitsMutable())
58  newIndexingMaps.emplace_back(
59  genericOp.getMatchingIndexingMap(&opOperand));
60 
61  Location loc = genericOp->getLoc();
62  SmallVector<Value> outputOperands = genericOp.getOutputs();
63  auto newOp = rewriter.create<GenericOp>(
64  loc, genericOp->getResultTypes(), newOperands, outputOperands,
65  newIndexingMaps, genericOp.getIteratorTypesArray());
66  rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
67  newOp.getRegion().begin());
68 
69  Block *body = newOp.getBody();
70  PatternRewriter::InsertionGuard guard(rewriter);
71  rewriter.setInsertionPointToStart(body);
72 
73  for (auto idx : llvm::reverse(scalarOperands)) {
74  OpOperand *opOperand = genericOp.getDpsInputOperand(idx);
75  AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
77  SmallVector<Value> indicesValues;
78  for (auto idx : indices)
79  indicesValues.emplace_back(
80  rewriter.create<arith::ConstantIndexOp>(loc, idx));
81  Value extractedValue = rewriter.create<tensor::ExtractOp>(
82  loc, opOperand->get(), indicesValues);
83  body->getArgument(idx).replaceAllUsesWith(extractedValue);
84  body->eraseArgument(idx);
85  }
86 
87  rewriter.replaceOp(genericOp, newOp->getResults());
88  return success();
89  }
90 };
91 } // namespace
92 
93 /// Patterns that are used to inline constant operands into linalg generic
94 /// ops.
96  RewritePatternSet &patterns) {
97  auto *context = patterns.getContext();
98  patterns.add<InlineScalarOperands>(context);
99 }
100 
101 namespace {
102 /// Pass that removes unit-extent dims within generic ops.
103 struct LinalgInlineScalarOperandsPass
104  : public impl::LinalgInlineScalarOperandsPassBase<
105  LinalgInlineScalarOperandsPass> {
106  using impl::LinalgInlineScalarOperandsPassBase<
107  LinalgInlineScalarOperandsPass>::LinalgInlineScalarOperandsPassBase;
108  void runOnOperation() override {
109  Operation *op = getOperation();
110  MLIRContext &ctx = getContext();
111  RewritePatternSet patterns(&ctx);
113  (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
114  }
115 };
116 } // namespace
static MLIRContext * getContext(OpFoldResult val)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
bool isConstant() const
Returns true if this affine map has only constant results.
Definition: AffineMap.cpp:361
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
Definition: AffineMap.cpp:370
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: Block.cpp:192
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
Definition: Builders.cpp:580
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:169
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns)
Patterns that are used to inline constant operands into linalg generic ops.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358