19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/Support/Debug.h"
22 #include <unordered_set>
26 #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATIONPASS
27 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
30 #define DEBUG_TYPE "outline-shape-computation"
41 llvm::SmallDenseSet<Value> inputSet;
42 llvm::SmallDenseSet<Operation *> opSet;
44 bool inserted = opSet.insert(op).second;
46 assert(inserted &&
"cluster contains duplicate operations");
50 for (
Value operand : op->getOperands()) {
51 Operation *operandOp = operand.getDefiningOp();
52 if (opSet.contains(operandOp)) {
56 if (inputSet.insert(operand).second)
57 inputs.push_back(operand);
64 std::pair<shape::FuncOp, SmallVector<Value>>
72 shape::FuncOp fnOp = b.
create<shape::FuncOp>(loc, fnName, fnType);
73 Block *block = fnOp.addEntryBlock();
76 if (cluster.empty()) {
77 bvm.
map(shape, fnOp.getArgument(0));
79 for (
auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
80 bvm.
map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
88 b.
create<shape::ReturnOp>(loc, fnReturns);
90 return std::make_pair(fnOp, inputs);
97 func::FuncOp funcOp) {
100 for (
const auto &it : clusters) {
101 Value shape = it.first;
104 op2Shapes[cOp].push_back(shape);
111 auto it = op2Shapes.find(op);
112 if (it != op2Shapes.end()) {
113 Operation *cOp = it->first;
114 for (Value shape : it->second)
115 orderedClusters[shape].push_back(cOp);
119 return orderedClusters;
122 void constructShapeFunc(
123 const std::vector<shape::WithOp> &allWithOps,
MLIRContext *context,
128 std::string shapeCalculationNamePrefix =
"shape_cal_";
129 int shapeCalculationNameIdx = 0;
133 for (shape::WithOp withOp : allWithOps) {
134 Value value = withOp.getOperand();
135 Value shape = withOp.getShape();
136 RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.
getType());
137 if (rankedType ==
nullptr)
142 auto it = dynShape2ShapeFunc.find(shape);
143 if (it == dynShape2ShapeFunc.end()) {
144 std::string name = shapeCalculationNamePrefix +
145 std::to_string(shapeCalculationNameIdx++);
148 auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
150 shape::FuncOp shapeFuncOp = pair.first;
151 StringAttr insertedName = symbolTable.
insert(shapeFuncOp);
155 shapeMappingValue.
inputs = inputs;
157 shapeMappingValue = it->second;
159 dynShape2ShapeFunc[shape] = shapeMappingValue;
161 std::make_pair(value, shapeMappingValue));
165 struct OutlineShapeComputationPass
166 :
public impl::OutlineShapeComputationPassBase<
167 OutlineShapeComputationPass> {
169 void runOnOperation()
override;
172 bool calOnlyUsedByWithShapesRecursively(
Operation *op,
Value prevOutput);
174 void getClusterFromValue(
Value shape,
178 constructClustersForEachShape(
const std::vector<shape::WithOp> &allWithOps,
179 func::FuncOp funcOp);
187 LogicalResult matchAndRewrite(tensor::DimOp op,
190 rewriter.
create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
197 void OutlineShapeComputationPass::runOnOperation() {
198 ModuleOp moduleOp = getOperation();
201 auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
205 markAnalysesPreserved<shape::ShapeMappingAnalysis>();
207 moduleOp.walk([&](func::FuncOp funcOp) {
210 prevPatterns.
insert<TensorDimOpRewriter>(context);
212 return signalPassFailure();
215 onlyUsedByWithShapes.
clear();
217 calOnlyUsedByWithShapesRecursively(op,
nullptr);
220 llvm::dbgs() <<
"onlyUsedByWithShapes table: \n";
221 for (
auto it : onlyUsedByWithShapes)
222 llvm::dbgs() << *it <<
"\n";
226 std::vector<shape::WithOp> allWithOps;
227 funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
230 constructClustersForEachShape(allWithOps, funcOp);
231 constructShapeFunc(allWithOps, context, clusters, symbolTable,
232 dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
234 for (shape::WithOp withOp : allWithOps) {
235 Value value = withOp.getOperand();
237 llvm::make_early_inc_range(withOp.getResult().getUsers())) {
238 if (
auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
249 if (valueOf.getType() == value.
getType())
252 valueOf.setOperand(value);
259 return signalPassFailure();
264 OutlineShapeComputationPass::constructClustersForEachShape(
265 const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
267 for (shape::WithOp withOp : allWithOps) {
268 Value shape = withOp.getShape();
269 if (clusters.count(shape) == 0)
270 getClusterFromValue(shape, clusters);
272 return getOrderedClusters(clusters, funcOp);
277 void OutlineShapeComputationPass::getClusterFromValue(
282 std::queue<Operation *> queue;
286 visited.insert(defOp);
289 while (!queue.empty()) {
292 if (onlyUsedByWithShapes.contains(op)) {
295 Operation *inpDefOp = inp.getDefiningOp();
296 if (
nullptr != inpDefOp && visited.insert(inpDefOp).second)
297 queue.push(inpDefOp);
302 clusters[shape] = std::move(cluster);
307 bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
309 if (onlyUsedByWithShapes.contains(op))
312 if (
auto withOp = llvm::dyn_cast<shape::WithOp>(op))
313 return withOp.getShape() == prevOutput;
320 if (!calOnlyUsedByWithShapesRecursively(user, oup))
323 onlyUsedByWithShapes.insert(op);
Block represents an ordered list of Operations.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void clear()
Clear out all of the held patterns in this list.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs
FlatSymbolRefAttr funcSymbol