MLIR  22.0.0git
OutlineShapeComputation.cpp
Go to the documentation of this file.
1 //====----- OutlineShapeComputation.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/IRMapping.h"
17 #include "llvm/ADT/DenseSet.h"
18 #include "llvm/Support/Debug.h"
19 #include <queue>
20 #include <vector>
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATIONPASS
24 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
25 } // namespace mlir
26 
27 #define DEBUG_TYPE "outline-shape-computation"
28 
29 using namespace mlir;
30 
31 namespace {
32 
33 // A Value is an input of the cluster if it is an operand of an operation in the
34 // cluster and its defining operation is not in the cluster.
36 getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
37  SmallVector<Value, 4> inputs;
38  llvm::SmallDenseSet<Value> inputSet;
39  llvm::SmallDenseSet<Operation *> opSet;
40  for (Operation *op : cluster) {
41  bool inserted = opSet.insert(op).second;
42  (void)inserted;
43  assert(inserted && "cluster contains duplicate operations");
44  }
45 
46  for (Operation *op : cluster) {
47  for (Value operand : op->getOperands()) {
48  Operation *operandOp = operand.getDefiningOp();
49  if (opSet.contains(operandOp)) {
50  // Skip if defining op is in the cluster.
51  continue;
52  }
53  if (inputSet.insert(operand).second)
54  inputs.push_back(operand);
55  }
56  }
57  return inputs;
58 }
59 
60 // Create a shape.func representing the shape computation for `shape`.
61 std::pair<shape::FuncOp, SmallVector<Value>>
62 createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
63  Value shape, StringRef fnName, Location loc) {
64  SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
65  auto fnType =
66  cluster.empty()
67  ? b.getFunctionType(shape.getType(), shape.getType())
68  : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
69  shape::FuncOp fnOp = shape::FuncOp::create(b, loc, fnName, fnType);
70  Block *block = fnOp.addEntryBlock();
71  b.setInsertionPointToEnd(block);
72  IRMapping bvm;
73  if (cluster.empty()) {
74  bvm.map(shape, fnOp.getArgument(0));
75  } else {
76  for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
77  bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
78  }
79 
80  for (Operation *op : cluster)
81  b.clone(*op, bvm);
83  fnReturns.push_back(bvm.lookupOrDefault(shape));
84 
85  shape::ReturnOp::create(b, loc, fnReturns);
86  fnOp.setPrivate();
87  return std::make_pair(fnOp, inputs);
88 }
89 
90 // The operations in the cluster might be unsorted, which could be inconvenient
91 // when creating shape.func op.
93 getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
94  func::FuncOp funcOp) {
95  // Compute all clusters that each operation is in
97  for (const auto &it : clusters) {
98  Value shape = it.first;
99  const DenseSet<Operation *> &cluster = it.second;
100  for (Operation *cOp : cluster)
101  op2Shapes[cOp].push_back(shape);
102  }
103 
104  // Iterate through all operations in order. Get all the clusters `cOp` belongs
105  // to and construct the new ordered cluster as it traverses.
107  funcOp.walk([&](Operation *op) {
108  auto it = op2Shapes.find(op);
109  if (it != op2Shapes.end()) {
110  Operation *cOp = it->first;
111  for (Value shape : it->second)
112  orderedClusters[shape].push_back(cOp);
113  }
114  });
115 
116  return orderedClusters;
117 }
118 
119 void constructShapeFunc(
120  const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
122  SymbolTable &symbolTable,
123  DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
124  func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
125  std::string shapeCalculationNamePrefix = "shape_cal_";
126  int shapeCalculationNameIdx = 0;
127  OpBuilder builder(context);
128 
129  // Construct a shape function
130  for (shape::WithOp withOp : allWithOps) {
131  Value value = withOp.getOperand();
132  Value shape = withOp.getShape();
133  RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
134  if (rankedType == nullptr)
135  continue;
136 
137  const SmallVector<Operation *, 8> &cluster = clusters[shape];
138  shape::ShapeMappingValue shapeMappingValue;
139  auto it = dynShape2ShapeFunc.find(shape);
140  if (it == dynShape2ShapeFunc.end()) {
141  std::string name = shapeCalculationNamePrefix +
142  std::to_string(shapeCalculationNameIdx++);
143  Location loc = value.getLoc();
144  builder.setInsertionPointAfter(funcOp);
145  auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
146  const SmallVector<Value> &inputs = pair.second;
147  shape::FuncOp shapeFuncOp = pair.first;
148  StringAttr insertedName = symbolTable.insert(shapeFuncOp);
149  auto symbol = FlatSymbolRefAttr::get(context, insertedName);
150 
151  shapeMappingValue.funcSymbol = symbol;
152  shapeMappingValue.inputs = inputs;
153  } else {
154  shapeMappingValue = it->second;
155  }
156  dynShape2ShapeFunc[shape] = shapeMappingValue;
157  shapeMappingAnalysis.shapeMapping.insert(
158  std::make_pair(value, shapeMappingValue));
159  }
160 }
161 
162 struct OutlineShapeComputationPass
163  : public impl::OutlineShapeComputationPassBase<
164  OutlineShapeComputationPass> {
165 
166  void runOnOperation() override;
167 
168 private:
169  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
170 
171  void getClusterFromValue(Value shape,
172  DenseMap<Value, DenseSet<Operation *>> &clusters);
173 
175  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
176  func::FuncOp funcOp);
177 
178  DenseSet<Operation *> onlyUsedByWithShapes;
179 };
180 
181 class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
183 
184  LogicalResult matchAndRewrite(tensor::DimOp op,
185  PatternRewriter &rewriter) const override {
186  auto shapeOf =
187  shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getSource());
188  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
189  op.getIndex());
190  return success();
191  }
192 };
193 
194 void OutlineShapeComputationPass::runOnOperation() {
195  ModuleOp moduleOp = getOperation();
196  SymbolTable symbolTable(moduleOp);
198  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
199  // TODO: This is as we populate this analysis during a pass that mutates. This
200  // pass currently requires 1 single module being compiled.
201  shapeMappingAnalysis.shapeMapping.clear();
202  markAnalysesPreserved<shape::ShapeMappingAnalysis>();
203 
204  moduleOp.walk([&](func::FuncOp funcOp) {
205  MLIRContext *context = funcOp.getContext();
206  RewritePatternSet prevPatterns(context);
207  prevPatterns.insert<TensorDimOpRewriter>(context);
208  if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
209  return signalPassFailure();
210 
211  // initialize class member `onlyUsedByWithShapes`
212  onlyUsedByWithShapes.clear();
213  funcOp.walk([&](Operation *op) {
214  calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
215  });
216  LLVM_DEBUG({
217  llvm::dbgs() << "onlyUsedByWithShapes table: \n";
218  for (auto it : onlyUsedByWithShapes)
219  llvm::dbgs() << *it << "\n";
220  });
221 
222  // collect all the shape.with_shape ops.
223  std::vector<shape::WithOp> allWithOps;
224  funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
225 
227  constructClustersForEachShape(allWithOps, funcOp);
228  constructShapeFunc(allWithOps, context, clusters, symbolTable,
229  dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
230 
231  for (shape::WithOp withOp : allWithOps) {
232  Value value = withOp.getOperand();
233  for (Operation *user :
234  llvm::make_early_inc_range(withOp.getResult().getUsers())) {
235  if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
236  // For pattern like
237  // %1 = shape.with_shape %arg1, %0
238  // %2 = shape.value_of %1
239  // because shape.value doesn't care the shape, the shape.with_shape is
240  // redundant.
241  // If type of %arg1 and %2 has same type, just
242  // replaced %2 with %arg1.
243  // If type of %arg1 has different type like !shape.value_shape,
244  // transform into
245  // %2 = shape.value_of %arg1
246  if (valueOf.getType() == value.getType())
247  valueOf.replaceAllUsesWith(value);
248  else
249  valueOf.setOperand(value);
250  }
251  }
252  }
253 
254  // Apply patterns, note this also performs DCE.
255  if (failed(applyPatternsGreedily(funcOp, {})))
256  return signalPassFailure();
257  });
258 }
259 
261 OutlineShapeComputationPass::constructClustersForEachShape(
262  const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
264  for (shape::WithOp withOp : allWithOps) {
265  Value shape = withOp.getShape();
266  if (clusters.count(shape) == 0)
267  getClusterFromValue(shape, clusters);
268  }
269  return getOrderedClusters(clusters, funcOp);
270 }
271 
272 // The output of a cluster is the `shape`, and the inputs are the outputs of
273 // operations who are not in `onlyUsedByWithShapes`
274 void OutlineShapeComputationPass::getClusterFromValue(
275  Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
276  DenseSet<Operation *> cluster;
277 
278  DenseSet<Operation *> visited;
279  std::queue<Operation *> queue;
280 
281  // defOp == nullptr means shape is the argument of the func op
282  if (Operation *defOp = shape.getDefiningOp()) {
283  visited.insert(defOp);
284  queue.push(defOp);
285  }
286  while (!queue.empty()) {
287  Operation *op = queue.front();
288  queue.pop();
289  if (onlyUsedByWithShapes.contains(op)) {
290  cluster.insert(op);
291  for (Value inp : op->getOperands()) {
292  Operation *inpDefOp = inp.getDefiningOp();
293  if (nullptr != inpDefOp && visited.insert(inpDefOp).second)
294  queue.push(inpDefOp);
295  }
296  }
297  }
298 
299  clusters[shape] = std::move(cluster);
300 }
301 
302 // Returns whether `op` is a shape.with_shape, or all the users' of `op`
303 // eventually point to the shape operand of shape.with_shape ops
304 bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
305  Operation *op, Value prevOutput) {
306  if (onlyUsedByWithShapes.contains(op))
307  return true;
308 
309  if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
310  return withOp.getShape() == prevOutput;
311 
312  if (op->use_empty())
313  return false;
314 
315  for (Value oup : op->getResults())
316  for (Operation *user : oup.getUsers())
317  if (!calOnlyUsedByWithShapesRecursively(user, oup))
318  return false;
319 
320  onlyUsedByWithShapes.insert(op);
321  return true;
322 }
323 
324 } // namespace
Block represents an ordered list of Operations.
Definition: Block.h:33
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:914
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:816
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:149
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs