MLIR 22.0.0git
OutlineShapeComputation.cpp
Go to the documentation of this file.
1//====----- OutlineShapeComputation.cpp -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/IRMapping.h"
17#include "llvm/ADT/DenseSet.h"
18#include "llvm/Support/Debug.h"
19#include <queue>
20#include <vector>
21
22namespace mlir {
23#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATIONPASS
24#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
25} // namespace mlir
26
27#define DEBUG_TYPE "outline-shape-computation"
28
29using namespace mlir;
30
31namespace {
33// A Value is an input of the cluster if it is an operand of an operation in the
34// cluster and its defining operation is not in the cluster.
36getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
38 llvm::SmallDenseSet<Value> inputSet;
39 llvm::SmallDenseSet<Operation *> opSet;
40 for (Operation *op : cluster) {
41 bool inserted = opSet.insert(op).second;
43 assert(inserted && "cluster contains duplicate operations");
44 }
45
46 for (Operation *op : cluster) {
47 for (Value operand : op->getOperands()) {
48 Operation *operandOp = operand.getDefiningOp();
49 if (opSet.contains(operandOp)) {
50 // Skip if defining op is in the cluster.
51 continue;
52 }
53 if (inputSet.insert(operand).second)
54 inputs.push_back(operand);
55 }
56 }
57 return inputs;
59
60// Create a shape.func representing the shape computation for `shape`.
61std::pair<shape::FuncOp, SmallVector<Value>>
62createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
63 Value shape, StringRef fnName, Location loc) {
64 SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
65 auto fnType =
66 cluster.empty()
67 ? b.getFunctionType(shape.getType(), shape.getType())
68 : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
69 shape::FuncOp fnOp = shape::FuncOp::create(b, loc, fnName, fnType);
70 Block *block = fnOp.addEntryBlock();
71 b.setInsertionPointToEnd(block);
72 IRMapping bvm;
73 if (cluster.empty()) {
74 bvm.map(shape, fnOp.getArgument(0));
75 } else {
76 for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
77 bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
78 }
79
80 for (Operation *op : cluster)
81 b.clone(*op, bvm);
83 fnReturns.push_back(bvm.lookupOrDefault(shape));
84
85 shape::ReturnOp::create(b, loc, fnReturns);
86 fnOp.setPrivate();
87 return std::make_pair(fnOp, inputs);
88}
89
90// The operations in the cluster might be unsorted, which could be inconvenient
91// when creating shape.func op.
93getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
94 func::FuncOp funcOp) {
95 // Compute all clusters that each operation is in
97 for (const auto &it : clusters) {
98 Value shape = it.first;
99 const DenseSet<Operation *> &cluster = it.second;
100 for (Operation *cOp : cluster)
101 op2Shapes[cOp].push_back(shape);
102 }
103
104 // Iterate through all operations in order. Get all the clusters `cOp` belongs
105 // to and construct the new ordered cluster as it traverses.
107 funcOp.walk([&](Operation *op) {
108 auto it = op2Shapes.find(op);
109 if (it != op2Shapes.end()) {
110 Operation *cOp = it->first;
111 for (Value shape : it->second)
112 orderedClusters[shape].push_back(cOp);
113 }
114 });
115
116 return orderedClusters;
117}
118
119void constructShapeFunc(
120 const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
122 SymbolTable &symbolTable,
124 func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
125 std::string shapeCalculationNamePrefix = "shape_cal_";
126 int shapeCalculationNameIdx = 0;
127 OpBuilder builder(context);
128
129 // Construct a shape function
130 for (shape::WithOp withOp : allWithOps) {
131 Value value = withOp.getOperand();
132 Value shape = withOp.getShape();
133 RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
134 if (rankedType == nullptr)
135 continue;
136
137 const SmallVector<Operation *, 8> &cluster = clusters[shape];
138 shape::ShapeMappingValue shapeMappingValue;
139 auto it = dynShape2ShapeFunc.find(shape);
140 if (it == dynShape2ShapeFunc.end()) {
141 std::string name = shapeCalculationNamePrefix +
142 std::to_string(shapeCalculationNameIdx++);
143 Location loc = value.getLoc();
144 builder.setInsertionPointAfter(funcOp);
145 auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
146 const SmallVector<Value> &inputs = pair.second;
147 shape::FuncOp shapeFuncOp = pair.first;
148 StringAttr insertedName = symbolTable.insert(shapeFuncOp);
149 auto symbol = FlatSymbolRefAttr::get(context, insertedName);
150
151 shapeMappingValue.funcSymbol = symbol;
152 shapeMappingValue.inputs = inputs;
153 } else {
154 shapeMappingValue = it->second;
155 }
156 dynShape2ShapeFunc[shape] = shapeMappingValue;
157 shapeMappingAnalysis.shapeMapping.insert(
158 std::make_pair(value, shapeMappingValue));
159 }
160}
161
162struct OutlineShapeComputationPass
164 OutlineShapeComputationPass> {
165
166 void runOnOperation() override;
167
168private:
169 bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
170
171 void getClusterFromValue(Value shape,
172 DenseMap<Value, DenseSet<Operation *>> &clusters);
173
175 constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
176 func::FuncOp funcOp);
177
178 DenseSet<Operation *> onlyUsedByWithShapes;
179};
180
181class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
182 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
183
184 LogicalResult matchAndRewrite(tensor::DimOp op,
185 PatternRewriter &rewriter) const override {
186 auto shapeOf =
187 shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getSource());
188 rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
189 op.getIndex());
190 return success();
191 }
192};
193
194void OutlineShapeComputationPass::runOnOperation() {
195 ModuleOp moduleOp = getOperation();
196 SymbolTable symbolTable(moduleOp);
198 auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
199 // TODO: This is as we populate this analysis during a pass that mutates. This
200 // pass currently requires 1 single module being compiled.
201 shapeMappingAnalysis.shapeMapping.clear();
202 markAnalysesPreserved<shape::ShapeMappingAnalysis>();
203
204 moduleOp.walk([&](func::FuncOp funcOp) {
205 MLIRContext *context = funcOp.getContext();
206 RewritePatternSet prevPatterns(context);
207 prevPatterns.insert<TensorDimOpRewriter>(context);
208 if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
209 return signalPassFailure();
210
211 // initialize class member `onlyUsedByWithShapes`
212 onlyUsedByWithShapes.clear();
213 funcOp.walk([&](Operation *op) {
214 calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
215 });
216 LLVM_DEBUG({
217 llvm::dbgs() << "onlyUsedByWithShapes table: \n";
218 for (auto it : onlyUsedByWithShapes)
219 llvm::dbgs() << *it << "\n";
220 });
221
222 // collect all the shape.with_shape ops.
223 std::vector<shape::WithOp> allWithOps;
224 funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
225
227 constructClustersForEachShape(allWithOps, funcOp);
228 constructShapeFunc(allWithOps, context, clusters, symbolTable,
229 dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
230
231 for (shape::WithOp withOp : allWithOps) {
232 Value value = withOp.getOperand();
233 for (Operation *user :
234 llvm::make_early_inc_range(withOp.getResult().getUsers())) {
235 if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
236 // For pattern like
237 // %1 = shape.with_shape %arg1, %0
238 // %2 = shape.value_of %1
239 // because shape.value doesn't care the shape, the shape.with_shape is
240 // redundant.
241 // If type of %arg1 and %2 has same type, just
242 // replaced %2 with %arg1.
243 // If type of %arg1 has different type like !shape.value_shape,
244 // transform into
245 // %2 = shape.value_of %arg1
246 if (valueOf.getType() == value.getType())
247 valueOf.replaceAllUsesWith(value);
248 else
249 valueOf.setOperand(value);
250 }
251 }
252 }
253
254 // Apply patterns, note this also performs DCE.
255 if (failed(applyPatternsGreedily(funcOp, {})))
256 return signalPassFailure();
257 });
258}
259
261OutlineShapeComputationPass::constructClustersForEachShape(
262 const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
264 for (shape::WithOp withOp : allWithOps) {
265 Value shape = withOp.getShape();
266 if (clusters.count(shape) == 0)
267 getClusterFromValue(shape, clusters);
268 }
269 return getOrderedClusters(clusters, funcOp);
270}
271
272// The output of a cluster is the `shape`, and the inputs are the outputs of
273// operations who are not in `onlyUsedByWithShapes`
274void OutlineShapeComputationPass::getClusterFromValue(
275 Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
276 DenseSet<Operation *> cluster;
277
278 DenseSet<Operation *> visited;
279 std::queue<Operation *> queue;
280
281 // defOp == nullptr means shape is the argument of the func op
282 if (Operation *defOp = shape.getDefiningOp()) {
283 visited.insert(defOp);
284 queue.push(defOp);
285 }
286 while (!queue.empty()) {
287 Operation *op = queue.front();
288 queue.pop();
289 if (onlyUsedByWithShapes.contains(op)) {
290 cluster.insert(op);
291 for (Value inp : op->getOperands()) {
292 Operation *inpDefOp = inp.getDefiningOp();
293 if (nullptr != inpDefOp && visited.insert(inpDefOp).second)
294 queue.push(inpDefOp);
295 }
296 }
297 }
298
299 clusters[shape] = std::move(cluster);
300}
301
302// Returns whether `op` is a shape.with_shape, or all the users' of `op`
303// eventually point to the shape operand of shape.with_shape ops
304bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
305 Operation *op, Value prevOutput) {
306 if (onlyUsedByWithShapes.contains(op))
307 return true;
308
309 if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
310 return withOp.getShape() == prevOutput;
311
312 if (op->use_empty())
313 return false;
314
315 for (Value oup : op->getResults())
316 for (Operation *user : oup.getUsers())
317 if (!calOnlyUsedByWithShapesRecursively(user, oup))
318 return false;
319
320 onlyUsedByWithShapes.insert(op);
321 return true;
322}
323
324} // namespace
return success()
b
Return true if `permutation` is a valid permutation of the `outer_dims_perm` (case OuterOrInnerPerm::Outer) or `inner_dims_pos` (case OuterOrInnerPerm::Inner) of the pack/unpack op.
Returns failure if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Note that `begin` can itself be invalidated by the memref rewriting done from this method.
Block represents an ordered list of Operations.
Definition Block.h:33
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs