17 #include "llvm/ADT/DenseSet.h"
18 #include "llvm/Support/Debug.h"
23 #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATIONPASS
24 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
27 #define DEBUG_TYPE "outline-shape-computation"
38 llvm::SmallDenseSet<Value> inputSet;
39 llvm::SmallDenseSet<Operation *> opSet;
41 bool inserted = opSet.insert(op).second;
43 assert(inserted &&
"cluster contains duplicate operations");
47 for (
Value operand : op->getOperands()) {
48 Operation *operandOp = operand.getDefiningOp();
49 if (opSet.contains(operandOp)) {
53 if (inputSet.insert(operand).second)
54 inputs.push_back(operand);
61 std::pair<shape::FuncOp, SmallVector<Value>>
69 shape::FuncOp fnOp = shape::FuncOp::create(b, loc, fnName, fnType);
70 Block *block = fnOp.addEntryBlock();
73 if (cluster.empty()) {
74 bvm.
map(shape, fnOp.getArgument(0));
76 for (
auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
77 bvm.
map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
85 shape::ReturnOp::create(b, loc, fnReturns);
87 return std::make_pair(fnOp, inputs);
94 func::FuncOp funcOp) {
97 for (
const auto &it : clusters) {
98 Value shape = it.first;
101 op2Shapes[cOp].push_back(shape);
108 auto it = op2Shapes.find(op);
109 if (it != op2Shapes.end()) {
110 Operation *cOp = it->first;
111 for (Value shape : it->second)
112 orderedClusters[shape].push_back(cOp);
116 return orderedClusters;
119 void constructShapeFunc(
120 const std::vector<shape::WithOp> &allWithOps,
MLIRContext *context,
125 std::string shapeCalculationNamePrefix =
"shape_cal_";
126 int shapeCalculationNameIdx = 0;
130 for (shape::WithOp withOp : allWithOps) {
131 Value value = withOp.getOperand();
132 Value shape = withOp.getShape();
133 RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.
getType());
134 if (rankedType ==
nullptr)
139 auto it = dynShape2ShapeFunc.find(shape);
140 if (it == dynShape2ShapeFunc.end()) {
141 std::string name = shapeCalculationNamePrefix +
142 std::to_string(shapeCalculationNameIdx++);
145 auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
147 shape::FuncOp shapeFuncOp = pair.first;
148 StringAttr insertedName = symbolTable.
insert(shapeFuncOp);
152 shapeMappingValue.
inputs = inputs;
154 shapeMappingValue = it->second;
156 dynShape2ShapeFunc[shape] = shapeMappingValue;
158 std::make_pair(value, shapeMappingValue));
162 struct OutlineShapeComputationPass
163 :
public impl::OutlineShapeComputationPassBase<
164 OutlineShapeComputationPass> {
166 void runOnOperation()
override;
169 bool calOnlyUsedByWithShapesRecursively(
Operation *op,
Value prevOutput);
171 void getClusterFromValue(
Value shape,
175 constructClustersForEachShape(
const std::vector<shape::WithOp> &allWithOps,
176 func::FuncOp funcOp);
184 LogicalResult matchAndRewrite(tensor::DimOp op,
187 shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getSource());
194 void OutlineShapeComputationPass::runOnOperation() {
195 ModuleOp moduleOp = getOperation();
198 auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
202 markAnalysesPreserved<shape::ShapeMappingAnalysis>();
204 moduleOp.walk([&](func::FuncOp funcOp) {
207 prevPatterns.
insert<TensorDimOpRewriter>(context);
209 return signalPassFailure();
212 onlyUsedByWithShapes.
clear();
214 calOnlyUsedByWithShapesRecursively(op,
nullptr);
217 llvm::dbgs() <<
"onlyUsedByWithShapes table: \n";
218 for (
auto it : onlyUsedByWithShapes)
219 llvm::dbgs() << *it <<
"\n";
223 std::vector<shape::WithOp> allWithOps;
224 funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
227 constructClustersForEachShape(allWithOps, funcOp);
228 constructShapeFunc(allWithOps, context, clusters, symbolTable,
229 dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
231 for (shape::WithOp withOp : allWithOps) {
232 Value value = withOp.getOperand();
234 llvm::make_early_inc_range(withOp.getResult().getUsers())) {
235 if (
auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
246 if (valueOf.getType() == value.
getType())
249 valueOf.setOperand(value);
256 return signalPassFailure();
261 OutlineShapeComputationPass::constructClustersForEachShape(
262 const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
264 for (shape::WithOp withOp : allWithOps) {
265 Value shape = withOp.getShape();
266 if (clusters.count(shape) == 0)
267 getClusterFromValue(shape, clusters);
269 return getOrderedClusters(clusters, funcOp);
274 void OutlineShapeComputationPass::getClusterFromValue(
279 std::queue<Operation *> queue;
283 visited.insert(defOp);
286 while (!queue.empty()) {
289 if (onlyUsedByWithShapes.contains(op)) {
292 Operation *inpDefOp = inp.getDefiningOp();
293 if (
nullptr != inpDefOp && visited.insert(inpDefOp).second)
294 queue.push(inpDefOp);
299 clusters[shape] = std::move(cluster);
304 bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
306 if (onlyUsedByWithShapes.contains(op))
309 if (
auto withOp = llvm::dyn_cast<shape::WithOp>(op))
310 return withOp.getShape() == prevOutput;
317 if (!calOnlyUsedByWithShapesRecursively(user, oup))
320 onlyUsedByWithShapes.insert(op);
Block represents an ordered list of Operations.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void clear()
Clear out all of the held patterns in this list.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs
FlatSymbolRefAttr funcSymbol