MLIR  21.0.0git
OutlineShapeComputation.cpp
Go to the documentation of this file.
1 //====----- OutlineShapeComputation.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <queue>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATIONPASS
27 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 #define DEBUG_TYPE "outline-shape-computation"
31 
32 using namespace mlir;
33 
34 namespace {
35 
36 // A Value is an input of the cluster if it is an operand of an operation in the
37 // cluster and its defining operation is not in the cluster.
39 getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
40  SmallVector<Value, 4> inputs;
41  llvm::SmallDenseSet<Value> inputSet;
42  llvm::SmallDenseSet<Operation *> opSet;
43  for (Operation *op : cluster) {
44  bool inserted = opSet.insert(op).second;
45  (void)inserted;
46  assert(inserted && "cluster contains duplicate operations");
47  }
48 
49  for (Operation *op : cluster) {
50  for (Value operand : op->getOperands()) {
51  Operation *operandOp = operand.getDefiningOp();
52  if (opSet.contains(operandOp)) {
53  // Skip if defining op is in the cluster.
54  continue;
55  }
56  if (inputSet.insert(operand).second)
57  inputs.push_back(operand);
58  }
59  }
60  return inputs;
61 }
62 
63 // Create a shape.func representing the shape computation for `shape`.
64 std::pair<shape::FuncOp, SmallVector<Value>>
65 createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
66  Value shape, StringRef fnName, Location loc) {
67  SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
68  auto fnType =
69  cluster.empty()
70  ? b.getFunctionType(shape.getType(), shape.getType())
71  : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
72  shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
73  Block *block = fnOp.addEntryBlock();
74  b.setInsertionPointToEnd(block);
75  IRMapping bvm;
76  if (cluster.empty()) {
77  bvm.map(shape, fnOp.getArgument(0));
78  } else {
79  for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
80  bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
81  }
82 
83  for (Operation *op : cluster)
84  b.clone(*op, bvm);
86  fnReturns.push_back(bvm.lookupOrDefault(shape));
87 
88  b.create<shape::ReturnOp>(loc, fnReturns);
89  fnOp.setPrivate();
90  return std::make_pair(fnOp, inputs);
91 }
92 
93 // The operations in the cluster might be unsorted, which could be inconvenient
94 // when creating shape.func op.
96 getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
97  func::FuncOp funcOp) {
98  // Compute all clusters that each operation is in
100  for (const auto &it : clusters) {
101  Value shape = it.first;
102  const DenseSet<Operation *> &cluster = it.second;
103  for (Operation *cOp : cluster)
104  op2Shapes[cOp].push_back(shape);
105  }
106 
107  // Iterate through all operations in order. Get all the clusters `cOp` belongs
108  // to and construct the new ordered cluster as it traverses.
110  funcOp.walk([&](Operation *op) {
111  auto it = op2Shapes.find(op);
112  if (it != op2Shapes.end()) {
113  Operation *cOp = it->first;
114  for (Value shape : it->second)
115  orderedClusters[shape].push_back(cOp);
116  }
117  });
118 
119  return orderedClusters;
120 }
121 
122 void constructShapeFunc(
123  const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
125  SymbolTable &symbolTable,
126  DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
127  func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
128  std::string shapeCalculationNamePrefix = "shape_cal_";
129  int shapeCalculationNameIdx = 0;
130  OpBuilder builder(context);
131 
132  // Construct a shape function
133  for (shape::WithOp withOp : allWithOps) {
134  Value value = withOp.getOperand();
135  Value shape = withOp.getShape();
136  RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
137  if (rankedType == nullptr)
138  continue;
139 
140  const SmallVector<Operation *, 8> &cluster = clusters[shape];
141  shape::ShapeMappingValue shapeMappingValue;
142  auto it = dynShape2ShapeFunc.find(shape);
143  if (it == dynShape2ShapeFunc.end()) {
144  std::string name = shapeCalculationNamePrefix +
145  std::to_string(shapeCalculationNameIdx++);
146  Location loc = value.getLoc();
147  builder.setInsertionPointAfter(funcOp);
148  auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
149  const SmallVector<Value> &inputs = pair.second;
150  shape::FuncOp shapeFuncOp = pair.first;
151  StringAttr insertedName = symbolTable.insert(shapeFuncOp);
152  auto symbol = FlatSymbolRefAttr::get(context, insertedName);
153 
154  shapeMappingValue.funcSymbol = symbol;
155  shapeMappingValue.inputs = inputs;
156  } else {
157  shapeMappingValue = it->second;
158  }
159  dynShape2ShapeFunc[shape] = shapeMappingValue;
160  shapeMappingAnalysis.shapeMapping.insert(
161  std::make_pair(value, shapeMappingValue));
162  }
163 }
164 
165 struct OutlineShapeComputationPass
166  : public impl::OutlineShapeComputationPassBase<
167  OutlineShapeComputationPass> {
168 
169  void runOnOperation() override;
170 
171 private:
172  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
173 
174  void getClusterFromValue(Value shape,
175  DenseMap<Value, DenseSet<Operation *>> &clusters);
176 
178  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
179  func::FuncOp funcOp);
180 
181  DenseSet<Operation *> onlyUsedByWithShapes;
182 };
183 
184 class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
186 
187  LogicalResult matchAndRewrite(tensor::DimOp op,
188  PatternRewriter &rewriter) const override {
189  auto shapeOf =
190  rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
191  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
192  op.getIndex());
193  return success();
194  }
195 };
196 
197 void OutlineShapeComputationPass::runOnOperation() {
198  ModuleOp moduleOp = getOperation();
199  SymbolTable symbolTable(moduleOp);
201  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
202  // TODO: This is as we populate this analysis during a pass that mutates. This
203  // pass currently requires 1 single module being compiled.
204  shapeMappingAnalysis.shapeMapping.clear();
205  markAnalysesPreserved<shape::ShapeMappingAnalysis>();
206 
207  moduleOp.walk([&](func::FuncOp funcOp) {
208  MLIRContext *context = funcOp.getContext();
209  RewritePatternSet prevPatterns(context);
210  prevPatterns.insert<TensorDimOpRewriter>(context);
211  if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
212  return signalPassFailure();
213 
214  // initialize class member `onlyUsedByWithShapes`
215  onlyUsedByWithShapes.clear();
216  funcOp.walk([&](Operation *op) {
217  calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
218  });
219  LLVM_DEBUG({
220  llvm::dbgs() << "onlyUsedByWithShapes table: \n";
221  for (auto it : onlyUsedByWithShapes)
222  llvm::dbgs() << *it << "\n";
223  });
224 
225  // collect all the shape.with_shape ops.
226  std::vector<shape::WithOp> allWithOps;
227  funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
228 
230  constructClustersForEachShape(allWithOps, funcOp);
231  constructShapeFunc(allWithOps, context, clusters, symbolTable,
232  dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
233 
234  for (shape::WithOp withOp : allWithOps) {
235  Value value = withOp.getOperand();
236  for (Operation *user :
237  llvm::make_early_inc_range(withOp.getResult().getUsers())) {
238  if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
239  // For pattern like
240  // %1 = shape.with_shape %arg1, %0
241  // %2 = shape.value_of %1
242  // because shape.value doesn't care the shape, the shape.with_shape is
243  // redundant.
244  // If type of %arg1 and %2 has same type, just
245  // replaced %2 with %arg1.
246  // If type of %arg1 has different type like !shape.value_shape,
247  // transform into
248  // %2 = shape.value_of %arg1
249  if (valueOf.getType() == value.getType())
250  valueOf.replaceAllUsesWith(value);
251  else
252  valueOf.setOperand(value);
253  }
254  }
255  }
256 
257  // Apply patterns, note this also performs DCE.
258  if (failed(applyPatternsGreedily(funcOp, {})))
259  return signalPassFailure();
260  });
261 }
262 
264 OutlineShapeComputationPass::constructClustersForEachShape(
265  const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
267  for (shape::WithOp withOp : allWithOps) {
268  Value shape = withOp.getShape();
269  if (clusters.count(shape) == 0)
270  getClusterFromValue(shape, clusters);
271  }
272  return getOrderedClusters(clusters, funcOp);
273 }
274 
275 // The output of a cluster is the `shape`, and the inputs are the outputs of
276 // operations who are not in `onlyUsedByWithShapes`
277 void OutlineShapeComputationPass::getClusterFromValue(
278  Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
279  DenseSet<Operation *> cluster;
280 
281  DenseSet<Operation *> visited;
282  std::queue<Operation *> queue;
283 
284  // defOp == nullptr means shape is the argument of the func op
285  if (Operation *defOp = shape.getDefiningOp()) {
286  visited.insert(defOp);
287  queue.push(defOp);
288  }
289  while (!queue.empty()) {
290  Operation *op = queue.front();
291  queue.pop();
292  if (onlyUsedByWithShapes.contains(op)) {
293  cluster.insert(op);
294  for (Value inp : op->getOperands()) {
295  Operation *inpDefOp = inp.getDefiningOp();
296  if (nullptr != inpDefOp && visited.insert(inpDefOp).second)
297  queue.push(inpDefOp);
298  }
299  }
300  }
301 
302  clusters[shape] = std::move(cluster);
303 }
304 
305 // Returns whether `op` is a shape.with_shape, or all the users' of `op`
306 // eventually point to the shape operand of shape.with_shape ops
307 bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
308  Operation *op, Value prevOutput) {
309  if (onlyUsedByWithShapes.contains(op))
310  return true;
311 
312  if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
313  return withOp.getShape() == prevOutput;
314 
315  if (op->use_empty())
316  return false;
317 
318  for (Value oup : op->getResults())
319  for (Operation *user : oup.getUsers())
320  if (!calOnlyUsedByWithShapesRecursively(user, oup))
321  return false;
322 
323  onlyUsedByWithShapes.insert(op);
324  return true;
325 }
326 
327 } // namespace
Block represents an ordered list of Operations.
Definition: Block.h:33
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:853
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
result_range getResults()
Definition: Operation.h:415
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:895
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:796
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:149
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs