19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/Support/Debug.h"
22 #include <unordered_set>
26 #define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
27 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
30 #define DEBUG_TYPE "outline-shape-computation"
41 llvm::SmallDenseSet<Value> inputSet;
42 llvm::SmallDenseSet<Operation *> opSet;
44 bool inserted = opSet.insert(op).second;
46 assert(inserted &&
"cluster contains duplicate operations");
50 for (
Value operand : op->getOperands()) {
51 Operation *operandOp = operand.getDefiningOp();
52 if (opSet.contains(operandOp)) {
56 if (inputSet.insert(operand).second)
57 inputs.push_back(operand);
64 std::pair<shape::FuncOp, SmallVector<Value>>
72 shape::FuncOp fnOp = b.
create<shape::FuncOp>(loc, fnName, fnType);
73 Block *block = fnOp.addEntryBlock();
76 if (cluster.empty()) {
77 bvm.
map(shape, fnOp.getArgument(0));
79 for (
auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
80 bvm.
map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
88 b.
create<shape::ReturnOp>(loc, fnReturns);
90 return std::make_pair(fnOp, inputs);
97 func::FuncOp funcOp) {
100 for (
const auto &it : clusters) {
101 Value shape = it.first;
104 op2Shapes[cOp].push_back(shape);
111 auto it = op2Shapes.find(op);
112 if (it != op2Shapes.end()) {
113 Operation *cOp = it->first;
114 for (Value shape : it->second)
115 orderedClusters[shape].push_back(cOp);
119 return orderedClusters;
122 void constructShapeFunc(
123 const std::vector<shape::WithOp> &allWithOps,
MLIRContext *context,
128 std::string shapeCalculationNamePrefix =
"shape_cal_";
129 int shapeCalculationNameIdx = 0;
133 for (shape::WithOp withOp : allWithOps) {
134 Value value = withOp.getOperand();
135 Value shape = withOp.getShape();
136 RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.
getType());
137 if (rankedType ==
nullptr)
142 auto it = dynShape2ShapeFunc.find(shape);
143 if (it == dynShape2ShapeFunc.end()) {
144 std::string name = shapeCalculationNamePrefix +
145 std::to_string(shapeCalculationNameIdx++);
148 auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
150 shape::FuncOp shapeFuncOp = pair.first;
151 StringAttr insertedName = symbolTable.
insert(shapeFuncOp);
155 shapeMappingValue.
inputs = inputs;
157 shapeMappingValue = it->second;
159 dynShape2ShapeFunc[shape] = shapeMappingValue;
161 std::make_pair(value, shapeMappingValue));
165 struct OutlineShapeComputationPass
166 :
public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {
168 void runOnOperation()
override;
171 bool calOnlyUsedByWithShapesRecursively(
Operation *op,
Value prevOutput);
173 void getClusterFromValue(
Value shape,
177 constructClustersForEachShape(
const std::vector<shape::WithOp> &allWithOps,
178 func::FuncOp funcOp);
186 LogicalResult matchAndRewrite(tensor::DimOp op,
189 rewriter.
create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
196 void OutlineShapeComputationPass::runOnOperation() {
197 ModuleOp moduleOp = getOperation();
200 auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
204 markAnalysesPreserved<shape::ShapeMappingAnalysis>();
206 moduleOp.walk([&](func::FuncOp funcOp) {
209 prevPatterns.
insert<TensorDimOpRewriter>(context);
211 return signalPassFailure();
214 onlyUsedByWithShapes.
clear();
216 calOnlyUsedByWithShapesRecursively(op,
nullptr);
219 llvm::dbgs() <<
"onlyUsedByWithShapes table: \n";
220 for (
auto it : onlyUsedByWithShapes)
221 llvm::dbgs() << *it <<
"\n";
225 std::vector<shape::WithOp> allWithOps;
226 funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
229 constructClustersForEachShape(allWithOps, funcOp);
230 constructShapeFunc(allWithOps, context, clusters, symbolTable,
231 dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
233 for (shape::WithOp withOp : allWithOps) {
234 Value value = withOp.getOperand();
236 llvm::make_early_inc_range(withOp.getResult().getUsers())) {
237 if (
auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
248 if (valueOf.getType() == value.
getType())
251 valueOf.setOperand(value);
258 return signalPassFailure();
263 OutlineShapeComputationPass::constructClustersForEachShape(
264 const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
266 for (shape::WithOp withOp : allWithOps) {
267 Value shape = withOp.getShape();
268 if (clusters.count(shape) == 0)
269 getClusterFromValue(shape, clusters);
271 return getOrderedClusters(clusters, funcOp);
276 void OutlineShapeComputationPass::getClusterFromValue(
281 std::queue<Operation *> queue;
285 visited.insert(defOp);
288 while (!queue.empty()) {
291 if (onlyUsedByWithShapes.contains(op)) {
294 Operation *inpDefOp = inp.getDefiningOp();
295 if (
nullptr != inpDefOp && visited.insert(inpDefOp).second)
296 queue.push(inpDefOp);
301 clusters[shape] = std::move(cluster);
306 bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
308 if (onlyUsedByWithShapes.contains(op))
311 if (
auto withOp = llvm::dyn_cast<shape::WithOp>(op))
312 return withOp.getShape() == prevOutput;
319 if (!calOnlyUsedByWithShapesRecursively(user, oup))
322 onlyUsedByWithShapes.insert(op);
328 std::unique_ptr<OperationPass<ModuleOp>>
330 return std::make_unique<OutlineShapeComputationPass>();
Block represents an ordered list of Operations.
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void clear()
Clear out all of the held patterns in this list.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesWith(Value newValue)
Replace all uses of 'this' value with the new value, updating anything in the IR that uses 'this' to ...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< OperationPass< ModuleOp > > createOutlineShapeComputationPass()
Outline the shape computation part by adding shape.func and populating the corresponding mapping information...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs
FlatSymbolRefAttr funcSymbol