MLIR  18.0.0git
InlineScalarOperands.cpp
Go to the documentation of this file.
1 //===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns/pass to inline scalar operands into a generic
10 // operation. A scalar operand is an input operand whose indexing map has only
11 // constant results, so it reads the same element on every loop iteration.
12 //
13 //===----------------------------------------------------------------------===//
14 
16 
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_LINALGINLINESCALAROPERANDS
27 #include "mlir/Dialect/Linalg/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
36  LogicalResult matchAndRewrite(GenericOp genericOp,
37  PatternRewriter &rewriter) const override {
38  if (!genericOp.hasTensorSemantics())
39  return failure();
40 
41  SmallVector<size_t> scalarOperands;
42  SmallVector<AffineMap> newIndexingMaps;
43  SmallVector<Value> newOperands;
44  for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
45  AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
46  if (genericOp.isDpsInput(opOperand) && map.isConstant()) {
47  scalarOperands.emplace_back(opOperand->getOperandNumber());
48  } else {
49  newIndexingMaps.emplace_back(map);
50  newOperands.emplace_back(opOperand->get());
51  }
52  }
53 
54  if (scalarOperands.empty())
55  return failure();
56 
57  for (OpOperand &opOperand : genericOp.getDpsInitsMutable())
58  newIndexingMaps.emplace_back(
59  genericOp.getMatchingIndexingMap(&opOperand));
60 
61  Location loc = genericOp->getLoc();
62  SmallVector<Value> outputOperands = genericOp.getOutputs();
63  auto newOp = rewriter.create<GenericOp>(
64  loc, genericOp->getResultTypes(), newOperands, outputOperands,
65  newIndexingMaps, genericOp.getIteratorTypesArray());
66  rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
67  newOp.getRegion().begin());
68 
69  Block *body = newOp.getBody();
70  PatternRewriter::InsertionGuard guard(rewriter);
71  rewriter.setInsertionPointToStart(body);
72 
73  for (auto idx : llvm::reverse(scalarOperands)) {
74  OpOperand *opOperand = genericOp.getDpsInputOperand(idx);
75  AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
77  SmallVector<Value> indicesValues;
78  for (auto idx : indices)
79  indicesValues.emplace_back(
80  rewriter.create<arith::ConstantIndexOp>(loc, idx));
81  Value extractedValue = rewriter.create<tensor::ExtractOp>(
82  loc, opOperand->get(), indicesValues);
83  body->getArgument(idx).replaceAllUsesWith(extractedValue);
84  body->eraseArgument(idx);
85  }
86 
87  rewriter.replaceOp(genericOp, newOp->getResults());
88  return success();
89  }
90 };
91 } // namespace
92 
93 /// Patterns that are used to inline constant operands into linalg generic
94 /// ops.
96  RewritePatternSet &patterns) {
97  auto *context = patterns.getContext();
98  patterns.add<InlineScalarOperands>(context);
99 }
100 
101 namespace {
102 /// Pass that removes unit-extent dims within generic ops.
103 struct LinalgInlineScalarOperandsPass
104  : public impl::LinalgInlineScalarOperandsBase<
105  LinalgInlineScalarOperandsPass> {
106  void runOnOperation() override {
107  Operation *op = getOperation();
108  MLIRContext &ctx = getContext();
109  RewritePatternSet patterns(&ctx);
111  (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
112  }
113 };
114 } // namespace
115 
117  return std::make_unique<LinalgInlineScalarOperandsPass>();
118 }
static MLIRContext * getContext(OpFoldResult val)
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
Definition: AffineMap.h:44
bool isConstant() const
Returns true if this affine map has only constant results.
Definition: AffineMap.cpp:318
SmallVector< int64_t > getConstantResults() const
Returns the constant results of this map.
Definition: AffineMap.cpp:329
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
void eraseArgument(unsigned index)
Erase the argument at 'index' and remove it from the argument list.
Definition: Block.cpp:187
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:150
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents an operand of an operation.
Definition: Value.h:261
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void cloneRegionBefore(Region &region, Region &parent, Region::iterator before, IRMapping &mapping)
Clone the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
void replaceAllUsesWith(Value newValue) const
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:166
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns)
Patterns that are used to inline constant operands into linalg generic ops.
This header declares functions that assist transformations in the MemRef dialect.
std::unique_ptr< Pass > createLinalgInlineScalarOperandsPass()
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357