MLIR  19.0.0git
OutlineShapeComputation.cpp
Go to the documentation of this file.
1 //====----- OutlineShapeComputation.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/Pass/Pass.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/Support/Debug.h"
21 #include <queue>
22 #include <unordered_set>
23 #include <vector>
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
27 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
28 } // namespace mlir
29 
30 #define DEBUG_TYPE "outline-shape-computation"
31 
32 using namespace mlir;
33 
34 namespace {
35 
36 // A Value is an input of the cluster if it is an operand of an operation in the
37 // cluster and its defining operation is not in the cluster.
//
// Returns the inputs of `cluster`, deduplicated and in first-encountered order
// (operations in `cluster` order, operands in operand order). A Value whose
// `getDefiningOp()` is null is never found in the op set, so it is always
// reported as an input.
// NOTE(review): this listing omits source line 38, which carries the return
// type; createFuncFromCluster binds the result to `SmallVector<Value, 4>`.
39 getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
40  SmallVector<Value, 4> inputs;
41  llvm::SmallDenseSet<Value> inputSet;
42  llvm::SmallDenseSet<Operation *> opSet;
    // Index the cluster for membership tests; a duplicate op is a caller bug.
43  for (Operation *op : cluster) {
44  bool inserted = opSet.insert(op).second;
    // `inserted` is otherwise only read by the assert below.
45  (void)inserted;
46  assert(inserted && "cluster contains duplicate operations");
47  }
48 
49  for (Operation *op : cluster) {
50  for (Value operand : op->getOperands()) {
51  Operation *operandOp = operand.getDefiningOp();
52  if (opSet.contains(operandOp)) {
53  // Skip if defining op is in the cluster.
54  continue;
55  }
    // `inputSet` filters repeats so each input is appended exactly once.
56  if (inputSet.insert(operand).second)
57  inputs.push_back(operand);
58  }
59  }
60  return inputs;
61 }
62 
63 // Create a shape.func representing the shape computation for `shape`.
//
// The ops in `cluster` are cloned, in the given order, into the entry block of
// a new private shape::FuncOp named `fnName`, which returns the value that
// `shape` maps to inside the function.
//  - Non-empty cluster: the function arguments are the cluster inputs (see
//    getInputsOfCluster) and the result type is the type of `shape`.
//  - Empty cluster: the function takes a single argument of the type of
//    `shape` and returns that argument.
// Returns the new function paired with the cluster inputs.
// Side effect: the insertion point of `b` is moved into the new function's
// entry block and is not restored here.
// NOTE(review): for an empty cluster the returned input list is empty even
// though the function has one argument -- confirm callers expect this.
// NOTE(review): this listing omits source line 85, which presumably declares
// `fnReturns`.
64 std::pair<shape::FuncOp, SmallVector<Value>>
65 createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
66  Value shape, StringRef fnName, Location loc) {
67  SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
68  auto fnType =
69  cluster.empty()
70  ? b.getFunctionType(shape.getType(), shape.getType())
71  : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
72  shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
73  Block *block = fnOp.addEntryBlock();
74  b.setInsertionPoint(block, block->end());
    // Map values used by the cluster to the corresponding function arguments.
75  IRMapping bvm;
76  if (cluster.empty()) {
77  bvm.map(shape, fnOp.getArgument(0));
78  } else {
79  for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
80  bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
81  }
82 
    // Clone in cluster order; `bvm` is passed to each clone to remap operands.
83  for (Operation *op : cluster)
84  b.clone(*op, bvm);
86  fnReturns.push_back(bvm.lookupOrDefault(shape));
87 
88  b.create<shape::ReturnOp>(loc, fnReturns);
89  fnOp.setPrivate();
90  return std::make_pair(fnOp, inputs);
91 }
92 
93 // The operations in the cluster might be unsorted, which could be inconvenient
94 // when creating shape.func op.
//
// Rebuilds every cluster as a sequence ordered by the `funcOp.walk` traversal:
// each visited op is appended to the ordered cluster of every shape whose
// input cluster contains it.
// NOTE(review): this listing omits source lines 95, 99 and 109 (the return
// type and, presumably, the declarations of `op2Shapes` and
// `orderedClusters`).
// NOTE(review): a shape whose cluster is empty is never touched by the code
// shown here, so it presumably has no entry in the result -- confirm.
96 getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
97  func::FuncOp funcOp) {
98  // Compute all clusters that each operation is in
100  for (const auto &it : clusters) {
101  Value shape = it.first;
102  const DenseSet<Operation *> &cluster = it.second;
103  for (Operation *cOp : cluster)
104  op2Shapes[cOp].push_back(shape);
105  }
106 
107  // Iterate through all operations in order. Get all the clusters `cOp` belongs
108  // to and construct the new ordered cluster as it traverses.
110  funcOp.walk([&](Operation *op) {
111  auto it = op2Shapes.find(op);
112  if (it != op2Shapes.end()) {
113  Operation *cOp = it->first;
114  for (Value shape : it->second)
115  orderedClusters[shape].push_back(cOp);
116  }
117  });
118 
119  return orderedClusters;
120 }
121 
// Outlines one shape function per distinct shape value used by the
// shape.with_shape ops in `allWithOps`, and records in `shapeMappingAnalysis`
// which function (and which inputs) describe the shape of each annotated value.
//
// with_shape ops whose operand is not a RankedTensorType are skipped.
// `dynShape2ShapeFunc` caches the result per shape value, so several
// with_shape ops sharing a shape reuse one function. New functions are created
// right after `funcOp`, named "shape_cal_<N>" with a counter local to this
// call, and registered in `symbolTable`; the name returned by the symbol table
// is the one stored in the mapping.
// NOTE(review): this listing omits source line 124; the body uses an otherwise
// undeclared `clusters`, so it is presumably the parameter declared there.
122 void constructShapeFunc(
123  const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
125  SymbolTable &symbolTable,
126  DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
127  func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
128  std::string shapeCalculationNamePrefix = "shape_cal_";
129  int shapeCalculationNameIdx = 0;
130  OpBuilder builder(context);
131 
132  // Construct a shape function
133  for (shape::WithOp withOp : allWithOps) {
134  Value value = withOp.getOperand();
135  Value shape = withOp.getShape();
136  RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
137  if (rankedType == nullptr)
138  continue;
139 
140  const SmallVector<Operation *, 8> &cluster = clusters[shape];
141  shape::ShapeMappingValue shapeMappingValue;
142  auto it = dynShape2ShapeFunc.find(shape);
143  if (it == dynShape2ShapeFunc.end()) {
     // First time this shape value is seen: outline its cluster.
144  std::string name = shapeCalculationNamePrefix +
145  std::to_string(shapeCalculationNameIdx++);
146  Location loc = value.getLoc();
147  builder.setInsertionPointAfter(funcOp);
148  auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
149  const SmallVector<Value> &inputs = pair.second;
150  shape::FuncOp shapeFuncOp = pair.first;
151  StringAttr insertedName = symbolTable.insert(shapeFuncOp);
152  auto symbol = FlatSymbolRefAttr::get(context, insertedName);
153 
154  shapeMappingValue.funcSymbol = symbol;
155  shapeMappingValue.inputs = inputs;
156  } else {
     // Shape already outlined: reuse the cached function symbol and inputs.
157  shapeMappingValue = it->second;
158  }
159  dynShape2ShapeFunc[shape] = shapeMappingValue;
160  shapeMappingAnalysis.shapeMapping.insert(
161  std::make_pair(value, shapeMappingValue));
162  }
163 }
164 
// Pass that outlines the shape computations feeding shape.with_shape ops into
// separate shape functions; see runOnOperation for the per-function steps.
// NOTE(review): this listing omits source line 176, which carries the return
// type of constructClustersForEachShape.
165 struct OutlineShapeComputationPass
166  : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
167 
168  void runOnOperation() override;
169 
170 private:
     // Fills `onlyUsedByWithShapes`; returns whether `op` qualifies when
     // reached through its operand `prevOutput`.
171  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
172 
     // Computes the (unordered) cluster of ops for `shape` and stores it in
     // `clusters[shape]`.
173  void getClusterFromValue(Value shape,
174  DenseMap<Value, DenseSet<Operation *>> &clusters);
175 
     // Builds the clusters for the shapes of all `allWithOps` and returns them
     // ordered via getOrderedClusters.
177  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
178  func::FuncOp funcOp);
179 
     // Ops accepted by calOnlyUsedByWithShapesRecursively; runOnOperation
     // clears it at the start of each function.
180  DenseSet<Operation *> onlyUsedByWithShapes;
181 };
182 
// Rewrites every tensor::DimOp into a shape::ShapeOfOp on the dim's source
// followed by a shape::GetExtentOp at the dim's index, keeping the original
// result type. The rewrite is unconditional and always reports success.
// NOTE(review): this listing omits source line 184 (presumably the access
// specifier / inherited-constructor line).
183 class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
185 
186  LogicalResult matchAndRewrite(tensor::DimOp op,
187  PatternRewriter &rewriter) const override {
188  auto shapeOf =
189  rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
190  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
191  op.getIndex());
192  return success();
193  }
194 };
195 
// Entry point of the pass. The ShapeMappingAnalysis mapping is cleared up front
// and the analysis is marked preserved. Then, for every func::FuncOp in the
// module:
//  1. tensor::DimOp is rewritten via TensorDimOpRewriter;
//  2. `onlyUsedByWithShapes` is recomputed for the function;
//  3. the shape.with_shape ops are collected, the ops computing their shapes
//     are clustered, and each cluster is outlined by constructShapeFunc;
//  4. shape::ValueOfOp users of with_shape results are redirected to the
//     with_shape operand;
//  5. a greedy rewrite with an empty pattern set is run for cleanup.
// NOTE(review): this listing omits source lines 199 and 228, which presumably
// declare `dynShape2ShapeFunc` and `clusters`.
// NOTE(review): `return signalPassFailure()` only leaves the callback for the
// current function; the walk goes on to the remaining functions -- confirm
// this is intended.
196 void OutlineShapeComputationPass::runOnOperation() {
197  ModuleOp moduleOp = getOperation();
198  SymbolTable symbolTable(moduleOp);
200  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
201  // TODO: This is as we populate this analysis during a pass that mutates. This
202  // pass currently requires 1 single module being compiled.
203  shapeMappingAnalysis.shapeMapping.clear();
204  markAnalysesPreserved<shape::ShapeMappingAnalysis>();
205 
206  moduleOp.walk([&](func::FuncOp funcOp) {
207  MLIRContext *context = funcOp.getContext();
208  RewritePatternSet prevPatterns(context);
209  prevPatterns.insert<TensorDimOpRewriter>(context);
210  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
211  return signalPassFailure();
212 
213  // initialize class member `onlyUsedByWithShapes`
214  onlyUsedByWithShapes.clear();
215  funcOp.walk([&](Operation *op) {
216  calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
217  });
218  LLVM_DEBUG({
219  llvm::dbgs() << "onlyUsedByWithShapes table: \n";
220  for (auto it : onlyUsedByWithShapes)
221  llvm::dbgs() << *it << "\n";
222  });
223 
224  // collect all the shape.with_shape ops.
225  std::vector<shape::WithOp> allWithOps;
226  funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
227 
229  constructClustersForEachShape(allWithOps, funcOp);
230  constructShapeFunc(allWithOps, context, clusters, symbolTable,
231  dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
232 
233  for (shape::WithOp withOp : allWithOps) {
234  Value value = withOp.getOperand();
      // Early-inc iteration: the loop body changes the uses being iterated.
235  for (Operation *user :
236  llvm::make_early_inc_range(withOp.getResult().getUsers())) {
237  if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
238  // For a pattern like
239  // %1 = shape.with_shape %arg1, %0
240  // %2 = shape.value_of %1
241  // shape.value_of does not depend on the shape, so the shape.with_shape
242  // is redundant.
243  // If %arg1 and %2 have the same type, simply
244  // replace %2 with %arg1.
245  // If %arg1 has a different type, such as !shape.value_shape,
246  // transform into
247  // %2 = shape.value_of %arg1
248  if (valueOf.getType() == value.getType())
249  valueOf.replaceAllUsesWith(value);
250  else
251  valueOf.setOperand(value);
252  }
253  }
254  }
255 
256  // Apply patterns, note this also performs DCE.
257  if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
258  return signalPassFailure();
259  });
260 }
261 
// Computes one cluster per distinct shape value used by `allWithOps` (via
// getClusterFromValue, skipping shapes that already have a cluster) and returns
// the clusters after ordering them with getOrderedClusters.
// NOTE(review): this listing omits source lines 262 and 265 (the return type
// and, presumably, the declaration of the local `clusters` map).
263 OutlineShapeComputationPass::constructClustersForEachShape(
264  const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
266  for (shape::WithOp withOp : allWithOps) {
267  Value shape = withOp.getShape();
268  if (clusters.count(shape) == 0)
269  getClusterFromValue(shape, clusters);
270  }
271  return getOrderedClusters(clusters, funcOp);
272 }
273 
274 // The output of a cluster is the `shape`, and the inputs are the outputs of
275 // operations that are not in `onlyUsedByWithShapes`.
//
// Performs a breadth-first traversal backwards along operands, starting at the
// defining op of `shape`. Only ops contained in `onlyUsedByWithShapes` are
// added to the cluster and have their operands explored; any other op stops
// the traversal along that path. The resulting (possibly empty) set is stored
// in `clusters[shape]`.
276 void OutlineShapeComputationPass::getClusterFromValue(
277  Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
278  DenseSet<Operation *> cluster;
279 
280  DenseSet<Operation *> visited;
281  std::queue<Operation *> queue;
282 
283  // defOp == nullptr means shape is the argument of the func op
284  if (Operation *defOp = shape.getDefiningOp()) {
285  visited.insert(defOp);
286  queue.push(defOp);
287  }
288  while (!queue.empty()) {
289  Operation *op = queue.front();
290  queue.pop();
291  if (onlyUsedByWithShapes.contains(op)) {
292  cluster.insert(op);
293  for (Value inp : op->getOperands()) {
294  Operation *inpDefOp = inp.getDefiningOp();
      // Operands without a defining op are not enqueued; `visited` ensures
      // each op is enqueued at most once.
295  if (nullptr != inpDefOp && !visited.contains(inpDefOp)) {
296  visited.insert(inpDefOp);
297  queue.push(inpDefOp);
298  }
299  }
300  }
301  }
302 
303  clusters[shape] = std::move(cluster);
304 }
305 
306 // Returns true if `op` is a shape.with_shape whose shape operand is
307 // `prevOutput`, or if all users of `op` (recursively) satisfy this.
//
// `prevOutput` is the value through which `op` was reached from its producer;
// the top-level walk passes a null Value. An op with no uses returns false.
// Ops that qualify through their users are memoized in `onlyUsedByWithShapes`;
// shape.with_shape ops themselves are never inserted, and negative results are
// not cached.
308 bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
309  Operation *op, Value prevOutput) {
310  if (onlyUsedByWithShapes.contains(op))
311  return true;
312 
     // Reaching a with_shape through anything but its shape operand
     // disqualifies the producer.
313  if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
314  return withOp.getShape() == prevOutput;
315 
316  if (op->use_empty())
317  return false;
318 
319  for (Value oup : op->getResults())
320  for (Operation *user : oup.getUsers())
321  if (!calOnlyUsedByWithShapesRecursively(user, oup))
322  return false;
323 
324  onlyUsedByWithShapes.insert(op);
325  return true;
326 }
327 
328 } // namespace
329 
// Factory returning a new OutlineShapeComputationPass instance.
// NOTE(review): this listing omits source line 331 (the function declarator,
// presumably `mlir::createOutlineShapeComputationPass()`).
330 std::unique_ptr<OperationPass<ModuleOp>>
332  return std::make_unique<OutlineShapeComputationPass>();
333 }
Block represents an ordered list of Operations.
Definition: Block.h:30
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:555
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
void clear()
Clear out all of the held patterns in this list.
Definition: PatternMatch.h:831
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Definition: Value.h:173
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Include the generated interface declarations.
std::unique_ptr< OperationPass< ModuleOp > > createOutlineShapeComputationPass()
Outline the shape computation part by adding shape.func and populate corresponding mapping information...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs