17#include "llvm/ADT/DenseSet.h"
18#include "llvm/Support/Debug.h"
23#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATIONPASS
24#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
27#define DEBUG_TYPE "outline-shape-computation"
38 llvm::SmallDenseSet<Value> inputSet;
39 llvm::SmallDenseSet<Operation *> opSet;
41 bool inserted = opSet.insert(op).second;
43 assert(
inserted &&
"cluster contains duplicate operations");
47 for (
Value operand : op->getOperands()) {
48 Operation *operandOp = operand.getDefiningOp();
49 if (opSet.contains(operandOp)) {
53 if (inputSet.insert(operand).second)
54 inputs.push_back(operand);
61std::pair<shape::FuncOp, SmallVector<Value>>
67 ?
b.getFunctionType(
shape.getType(),
shape.getType())
69 shape::FuncOp fnOp = shape::FuncOp::create(
b, loc, fnName, fnType);
70 Block *block = fnOp.addEntryBlock();
71 b.setInsertionPointToEnd(block);
73 if (cluster.empty()) {
76 for (
auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
77 bvm.
map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
85 shape::ReturnOp::create(
b, loc, fnReturns);
87 return std::make_pair(fnOp, inputs);
94 func::FuncOp funcOp) {
97 for (
const auto &it : clusters) {
101 op2Shapes[cOp].push_back(
shape);
108 auto it = op2Shapes.find(op);
109 if (it != op2Shapes.end()) {
112 orderedClusters[
shape].push_back(cOp);
116 return orderedClusters;
119void constructShapeFunc(
120 const std::vector<shape::WithOp> &allWithOps,
MLIRContext *context,
125 std::string shapeCalculationNamePrefix =
"shape_cal_";
126 int shapeCalculationNameIdx = 0;
130 for (shape::WithOp withOp : allWithOps) {
131 Value value = withOp.getOperand();
133 RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.
getType());
134 if (rankedType ==
nullptr)
139 auto it = dynShape2ShapeFunc.find(
shape);
140 if (it == dynShape2ShapeFunc.end()) {
141 std::string name = shapeCalculationNamePrefix +
142 std::to_string(shapeCalculationNameIdx++);
144 builder.setInsertionPointAfter(funcOp);
145 auto pair = createFuncFromCluster(builder, cluster,
shape, name, loc);
147 shape::FuncOp shapeFuncOp = pair.first;
148 StringAttr insertedName = symbolTable.
insert(shapeFuncOp);
152 shapeMappingValue.
inputs = inputs;
154 shapeMappingValue = it->second;
156 dynShape2ShapeFunc[
shape] = shapeMappingValue;
158 std::make_pair(value, shapeMappingValue));
162struct OutlineShapeComputationPass
164 OutlineShapeComputationPass> {
166 void runOnOperation()
override;
169 bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);
171 void getClusterFromValue(Value shape,
175 constructClustersForEachShape(
const std::vector<shape::WithOp> &allWithOps,
176 func::FuncOp funcOp);
182 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
184 LogicalResult matchAndRewrite(tensor::DimOp op,
185 PatternRewriter &rewriter)
const override {
187 shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getSource());
194void OutlineShapeComputationPass::runOnOperation() {
195 ModuleOp moduleOp = getOperation();
196 SymbolTable symbolTable(moduleOp);
198 auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
202 markAnalysesPreserved<shape::ShapeMappingAnalysis>();
204 moduleOp.walk([&](func::FuncOp funcOp) {
205 MLIRContext *context = funcOp.getContext();
206 RewritePatternSet prevPatterns(context);
207 prevPatterns.insert<TensorDimOpRewriter>(context);
209 return signalPassFailure();
212 onlyUsedByWithShapes.clear();
213 funcOp.walk([&](Operation *op) {
214 calOnlyUsedByWithShapesRecursively(op,
nullptr);
217 llvm::dbgs() <<
"onlyUsedByWithShapes table: \n";
218 for (
auto it : onlyUsedByWithShapes)
219 llvm::dbgs() << *it <<
"\n";
223 std::vector<shape::WithOp> allWithOps;
224 funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });
227 constructClustersForEachShape(allWithOps, funcOp);
228 constructShapeFunc(allWithOps, context, clusters, symbolTable,
229 dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
231 for (shape::WithOp withOp : allWithOps) {
232 Value value = withOp.getOperand();
233 for (Operation *user :
234 llvm::make_early_inc_range(withOp.getResult().getUsers())) {
235 if (
auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
246 if (valueOf.getType() == value.
getType())
247 valueOf.replaceAllUsesWith(value);
249 valueOf.setOperand(value);
256 return signalPassFailure();
261OutlineShapeComputationPass::constructClustersForEachShape(
262 const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
264 for (shape::WithOp withOp : allWithOps) {
265 Value shape = withOp.getShape();
266 if (clusters.count(shape) == 0)
267 getClusterFromValue(shape, clusters);
269 return getOrderedClusters(clusters, funcOp);
274void OutlineShapeComputationPass::getClusterFromValue(
279 std::queue<Operation *> queue;
283 visited.insert(defOp);
286 while (!queue.empty()) {
287 Operation *op = queue.front();
289 if (onlyUsedByWithShapes.contains(op)) {
292 Operation *inpDefOp = inp.getDefiningOp();
293 if (
nullptr != inpDefOp && visited.insert(inpDefOp).second)
294 queue.push(inpDefOp);
299 clusters[shape] = std::move(cluster);
304bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
305 Operation *op, Value prevOutput) {
306 if (onlyUsedByWithShapes.contains(op))
309 if (
auto withOp = llvm::dyn_cast<shape::WithOp>(op))
310 return withOp.getShape() == prevOutput;
316 for (Operation *user : oup.getUsers())
317 if (!calOnlyUsedByWithShapesRecursively(user, oup))
320 onlyUsedByWithShapes.insert(op);
...if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method...
Block represents an ordered list of Operations.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
ShapeMappingAnalysis is used together with OutlineShapeComputationPass to preserve Value and correspo...
llvm::DenseMap< Value, ShapeMappingValue > shapeMapping
ShapeMappingValue works as the value of ShapeMappingAnalysis table, where funcSymbol is the symbol of...
llvm::SmallVector< Value > inputs
FlatSymbolRefAttr funcSymbol