MLIR  20.0.0git
OutlineShapeComputation.cpp
Go to the documentation of this file.
1 //====----- OutlineShapeComputation.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/Support/Debug.h"
21 #include <queue>
22 #include <unordered_set>
23 #include <vector>
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
27 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 #define DEBUG_TYPE "outline-shape-computation"
31 
32 using namespace mlir;
33 
34 namespace {
35 
36 // A Value is an input of the cluster if it is an operand of an operation in the
37 // cluster and its defining operation is not in the cluster.
// Inputs are returned without duplicates, in first-use order: operations are
// scanned in `cluster` order and each operation's operands in operand order.
// Block arguments (values with no defining op) are always reported as inputs
// (assuming `cluster` never holds a null pointer).
39 getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
40  SmallVector<Value, 4> inputs;
41  llvm::SmallDenseSet<Value> inputSet;
42  llvm::SmallDenseSet<Operation *> opSet;
  // Index the cluster for O(1) membership tests; builds with assertions
  // enabled reject a cluster that lists the same operation twice.
43  for (Operation *op : cluster) {
44  bool inserted = opSet.insert(op).second;
45  (void)inserted;
46  assert(inserted && "cluster contains duplicate operations");
47  }
48 
49  for (Operation *op : cluster) {
50  for (Value operand : op->getOperands()) {
51  Operation *operandOp = operand.getDefiningOp();
52  if (opSet.contains(operandOp)) {
53  // Skip if defining op is in the cluster.
54  continue;
55  }
  // `inputSet` deduplicates while `inputs` preserves first-use order.
56  if (inputSet.insert(operand).second)
57  inputs.push_back(operand);
58  }
59  }
60  return inputs;
61 }
62 
63 // Create a shape.func representing the shape computation for `shape`.
// The function's parameters are the cluster inputs computed by
// getInputsOfCluster, or a single parameter of `shape`'s type when `cluster`
// is empty. Its body is a clone of the ops of `cluster` in the given order,
// and it returns the value `shape` maps to. The function is created at `b`'s
// current insertion point, named `fnName`, and marked private. On return, `b`'s
// insertion point is left at the end of the new function's entry block.
// Returns the new function together with the cluster inputs.
// NOTE(review): when `cluster` is empty the function takes `shape` itself as
// its argument, but the returned input list is empty (getInputsOfCluster of
// an empty cluster), so it does not match the function's arity. Confirm
// callers expect this.
64 std::pair<shape::FuncOp, SmallVector<Value>>
65 createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
66  Value shape, StringRef fnName, Location loc) {
67  SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
68  auto fnType =
69  cluster.empty()
70  ? b.getFunctionType(shape.getType(), shape.getType())
71  : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
72  shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
73  Block *block = fnOp.addEntryBlock();
74  b.setInsertionPoint(block, block->end());
  // Map each outside value (the cluster inputs, or `shape` itself for an
  // empty cluster) to the matching entry-block argument so the clones
  // reference the function's parameters.
75  IRMapping bvm;
76  if (cluster.empty()) {
77  bvm.map(shape, fnOp.getArgument(0));
78  } else {
79  for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
80  bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
81  }
82 
  // Clone in cluster order. An operand produced by an earlier clone is
  // remapped through `bvm`, so `cluster` is expected to list ops in
  // definition order (see getOrderedClusters).
83  for (Operation *op : cluster)
84  b.clone(*op, bvm);
  // `shape` resolves to the clone of its defining op, or to argument 0 when
  // the cluster is empty.
86  fnReturns.push_back(bvm.lookupOrDefault(shape));
87 
88  b.create<shape::ReturnOp>(loc, fnReturns);
89  fnOp.setPrivate();
90  return std::make_pair(fnOp, inputs);
91 }
92 
93 // The operations in each cluster are collected in an unordered set, which is
94 // inconvenient when creating the shape.func op: an op must be cloned after
// the in-cluster ops that define its operands. This rebuilds every cluster as
// a sequence ordered by the order in which `funcOp.walk` visits operations.
// An operation that belongs to several clusters appears in each of them.
// Shapes whose cluster is empty get no entry in the result.
// NOTE(review): this relies on walk order matching program order. MLIR's
// default walk is post-order, which yields program order within a single
// block. Confirm this for clusters that span nested regions.
96 getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
97  func::FuncOp funcOp) {
98  // Compute all clusters that each operation is in
100  for (const auto &it : clusters) {
101  Value shape = it.first;
102  const DenseSet<Operation *> &cluster = it.second;
103  for (Operation *cOp : cluster)
104  op2Shapes[cOp].push_back(shape);
105  }
106 
107  // Iterate through all operations in order. Get all the clusters `cOp` belongs
108  // to and construct the new ordered cluster as it traverses.
110  funcOp.walk([&](Operation *op) {
111  auto it = op2Shapes.find(op);
112  if (it != op2Shapes.end()) {
113  Operation *cOp = it->first;
114  for (Value shape : it->second)
115  orderedClusters[shape].push_back(cOp);
116  }
117  });
118 
119  return orderedClusters;
120 }
121 
// For every shape.with_shape op in `allWithOps` whose operand is a ranked
// tensor, outlines the ordered cluster of its shape value into a private
// shape.func. The function is inserted right after `funcOp` and registered in
// `symbolTable`, and the mapping value -> {function symbol, inputs} is recorded
// in `shapeMappingAnalysis`. A shape value that was already outlined (tracked
// in `dynShape2ShapeFunc`) reuses its existing function. With ops whose operand
// is not a RankedTensorType are skipped.
// Function names are "shape_cal_<N>", with N restarting at 0 on every call.
// SymbolTable::insert renames on collision, so the final symbol may differ.
122 void constructShapeFunc(
123  const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
125  SymbolTable &symbolTable,
126  DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
127  func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
128  std::string shapeCalculationNamePrefix = "shape_cal_";
129  int shapeCalculationNameIdx = 0;
130  OpBuilder builder(context);
131 
132  // Construct a shape function
133  for (shape::WithOp withOp : allWithOps) {
134  Value value = withOp.getOperand();
135  Value shape = withOp.getShape();
136  RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
137  if (rankedType == nullptr)
138  continue;
139 
  // operator[] default-inserts an empty cluster for shapes with no
  // outlinable ops (getOrderedClusters emits no entry for those).
140  const SmallVector<Operation *, 8> &cluster = clusters[shape];
141  shape::ShapeMappingValue shapeMappingValue;
142  auto it = dynShape2ShapeFunc.find(shape);
143  if (it == dynShape2ShapeFunc.end()) {
144  std::string name = shapeCalculationNamePrefix +
145  std::to_string(shapeCalculationNameIdx++);
146  Location loc = value.getLoc();
  // The insertion point is reset to right after `funcOp` on every
  // iteration, so newer shape functions land before older ones.
147  builder.setInsertionPointAfter(funcOp);
148  auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
149  const SmallVector<Value> &inputs = pair.second;
150  shape::FuncOp shapeFuncOp = pair.first;
151  StringAttr insertedName = symbolTable.insert(shapeFuncOp);
152  auto symbol = FlatSymbolRefAttr::get(context, insertedName);
153 
154  shapeMappingValue.funcSymbol = symbol;
155  shapeMappingValue.inputs = inputs;
156  } else {
157  shapeMappingValue = it->second;
158  }
  // Redundant when the entry already existed, but harmless.
159  dynShape2ShapeFunc[shape] = shapeMappingValue;
  // DenseMap::insert does not overwrite: if `value` already has a mapping,
  // the earlier one is kept.
160  shapeMappingAnalysis.shapeMapping.insert(
161  std::make_pair(value, shapeMappingValue));
162  }
163 }
164 
// Module pass that, for each func.func, outlines the computation of every
// shape attached through shape.with_shape into a private shape.func and
// records the value -> shape-function mapping in ShapeMappingAnalysis.
165 struct OutlineShapeComputationPass
166  : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
167 
168  void runOnOperation() override;
169 
170 private:
  // Returns whether all uses of `op` end up (transitively) only in shape
  // operands of shape.with_shape ops. Positive answers are cached in
  // `onlyUsedByWithShapes`.
171  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
172 
  // getClusterFromValue stores in clusters[shape] the (unordered) set of ops
  // computing `shape` that are in `onlyUsedByWithShapes`.
  // constructClustersForEachShape (declared next) does this for every
  // distinct with_shape shape operand and returns the clusters in program
  // order.
173  void getClusterFromValue(Value shape,
174  DenseMap<Value, DenseSet<Operation *>> &clusters);
175 
177  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
178  func::FuncOp funcOp);
179 
  // Ops whose results are consumed only (transitively) as shape operands of
  // shape.with_shape ops. Cleared and recomputed for each function in
  // runOnOperation.
180  DenseSet<Operation *> onlyUsedByWithShapes;
181 };
182 
// Rewrites `tensor.dim %src, %idx` into
// `shape.get_extent (shape.shape_of %src), %idx`, keeping the original result
// type. runOnOperation applies it before clustering, presumably so that
// dimension queries become shape-dialect ops that can be outlined. TODO
// confirm.
183 class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
185 
186  LogicalResult matchAndRewrite(tensor::DimOp op,
187  PatternRewriter &rewriter) const override {
  // Unconditional: every matched tensor.dim is rewritten.
188  auto shapeOf =
189  rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
190  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
191  op.getIndex());
192  return success();
193  }
194 };
195 
// Pass entry point. For every func.func in the module, it:
//   - rewrites tensor.dim into shape ops;
//   - computes which ops feed only shape.with_shape shape operands;
//   - outlines the shape computation of each with_shape into a shape.func;
//   - records the mapping in ShapeMappingAnalysis;
//   - rewires shape.value_of users of each with_shape result to the
//     with_shape operand;
//   - finally runs the greedy driver with no patterns to fold and erase dead
//     ops.
196 void OutlineShapeComputationPass::runOnOperation() {
197  ModuleOp moduleOp = getOperation();
198  SymbolTable symbolTable(moduleOp);
200  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
201  // TODO: This is as we populate this analysis during a pass that mutates. This
202  // pass currently requires 1 single module being compiled.
203  shapeMappingAnalysis.shapeMapping.clear();
204  markAnalysesPreserved<shape::ShapeMappingAnalysis>();
205 
206  moduleOp.walk([&](func::FuncOp funcOp) {
207  MLIRContext *context = funcOp.getContext();
208  RewritePatternSet prevPatterns(context);
209  prevPatterns.insert<TensorDimOpRewriter>(context);
  // NOTE(review): `return signalPassFailure();` only exits this lambda. The
  // pass is marked as failed, but the walk continues with the remaining
  // functions.
210  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
211  return signalPassFailure();
212 
213  // initialize class member `onlyUsedByWithShapes`
214  onlyUsedByWithShapes.clear();
215  funcOp.walk([&](Operation *op) {
216  calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
217  });
218  LLVM_DEBUG({
219  llvm::dbgs() << "onlyUsedByWithShapes table: \n";
220  for (auto it : onlyUsedByWithShapes)
221  llvm::dbgs() << *it << "\n";
222  });
223 
224  // collect all the shape.with_shape ops.
225  std::vector<shape::WithOp> allWithOps;
226  funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
227 
229  constructClustersForEachShape(allWithOps, funcOp);
230  constructShapeFunc(allWithOps, context, clusters, symbolTable,
231  dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
232 
233  for (shape::WithOp withOp : allWithOps) {
234  Value value = withOp.getOperand();
  // Early-increment range: a user may be rewired (setOperand) while the
  // use list is being iterated.
235  for (Operation *user :
236  llvm::make_early_inc_range(withOp.getResult().getUsers())) {
237  if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
238  // For a pattern like
239  // %1 = shape.with_shape %arg1, %0
240  // %2 = shape.value_of %1
241  // the shape.with_shape is redundant, because shape.value_of does not
242  // depend on the attached shape.
243  // If %arg1 and %2 have the same type, just
244  // replace all uses of %2 with %arg1.
245  // If %arg1 has a different type (e.g. !shape.value_shape),
246  // transform into
247  // %2 = shape.value_of %arg1
248  if (valueOf.getType() == value.getType())
249  valueOf.replaceAllUsesWith(value);
250  else
251  valueOf.setOperand(value);
252  }
253  }
254  }
255 
256  // Apply patterns, note this also performs DCE.
257  if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
258  return signalPassFailure();
259  });
260 }
261 
263 OutlineShapeComputationPass::constructClustersForEachShape(
264  const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
  // Build one unordered cluster per distinct shape operand of the collected
  // with ops (several with ops may share a shape value), then order every
  // cluster by the position of its ops in `funcOp`.
266  for (shape::WithOp withOp : allWithOps) {
267  Value shape = withOp.getShape();
268  if (clusters.count(shape) == 0)
269  getClusterFromValue(shape, clusters);
270  }
271  return getOrderedClusters(clusters, funcOp);
272 }
273 
274 // The output of a cluster is the `shape`, and the inputs are the outputs of
275 // operations that are not in `onlyUsedByWithShapes` (or block arguments).
// Performs a breadth-first search backwards from `shape`'s defining op over
// the ops defining each operand. An op joins the cluster only if it is in
// `onlyUsedByWithShapes`. The search does not expand past any other op, so
// that op's results become inputs of the cluster. The (unordered) result is
// stored in clusters[shape], overwriting any existing entry.
276 void OutlineShapeComputationPass::getClusterFromValue(
277  Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
278  DenseSet<Operation *> cluster;
279 
  // `visited` ensures each op is enqueued at most once, even when it is
  // reachable through several operands.
280  DenseSet<Operation *> visited;
281  std::queue<Operation *> queue;
282 
283  // defOp == nullptr means `shape` is a block argument (e.g. an argument of the func op), so the cluster stays empty.
284  if (Operation *defOp = shape.getDefiningOp()) {
285  visited.insert(defOp);
286  queue.push(defOp);
287  }
288  while (!queue.empty()) {
289  Operation *op = queue.front();
290  queue.pop();
  // Ops outside `onlyUsedByWithShapes` are boundaries: they are neither
  // added nor expanded.
291  if (onlyUsedByWithShapes.contains(op)) {
292  cluster.insert(op);
293  for (Value inp : op->getOperands()) {
294  Operation *inpDefOp = inp.getDefiningOp();
295  if (nullptr != inpDefOp && !visited.contains(inpDefOp)) {
296  visited.insert(inpDefOp);
297  queue.push(inpDefOp);
298  }
299  }
300  }
301  }
302 
303  clusters[shape] = std::move(cluster);
304 }
305 
306 // Returns whether every use of `op` eventually (transitively) ends up only in
307 // the shape operand of shape.with_shape ops.
// For a shape.with_shape `op` itself, returns whether `prevOutput` (the value
// through which it was reached) is its shape operand. With ops are never
// added to the cache.
// An op without uses returns false. Positive results are memoized in
// `onlyUsedByWithShapes`. Negative results are not cached and are recomputed
// on every query.
// NOTE(review): without negative memoization, heavily shared use-def graphs
// may be re-traversed many times. Consider caching failures if this shows up
// in profiles.
308 bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
309  Operation *op, Value prevOutput) {
310  if (onlyUsedByWithShapes.contains(op))
311  return true;
312 
  // Reaching a with op through its non-shape operand is a disqualifying use.
313  if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
314  return withOp.getShape() == prevOutput;
315 
316  if (op->use_empty())
317  return false;
318 
319  for (Value oup : op->getResults())
320  for (Operation *user : oup.getUsers())
321  if (!calOnlyUsedByWithShapesRecursively(user, oup))
322  return false;
323 
324  onlyUsedByWithShapes.insert(op);
325  return true;
326 }
327 
328 } // namespace
329 
// Creates a new instance of the outline-shape-computation module pass.
330 std::unique_ptr<OperationPass<ModuleOp>>
332  return std::make_unique<OutlineShapeComputationPass>();
333 }
Block represents an ordered list of Operations.
Definition: Block.h:31
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:108
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:212
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:567
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:403
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:417
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:832
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:173
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
std::unique_ptr< OperationPass< ModuleOp > > createOutlineShapeComputationPass()
Outline the shape computation part by adding shape.func and populating the corresponding mapping information...
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs