MLIR  21.0.0git
WrapFuncInClass.cpp
Go to the documentation of this file.
1 //===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace emitc;
20 
21 namespace mlir {
22 namespace emitc {
23 #define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
24 #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
25 
26 namespace {
27 struct WrapFuncInClassPass
28  : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
29  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
30  void runOnOperation() override {
31  Operation *rootOp = getOperation();
32 
34  populateFuncPatterns(patterns, namedAttribute);
35 
36  walkAndApplyPatterns(rootOp, std::move(patterns));
37  }
38 };
39 
40 } // namespace
41 } // namespace emitc
42 } // namespace mlir
43 
44 class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
45 public:
46  WrapFuncInClass(MLIRContext *context, StringRef attrName)
47  : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
48 
49  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
50  PatternRewriter &rewriter) const override {
51 
52  auto className = funcOp.getSymNameAttr().str() + "Class";
53  ClassOp newClassOp = rewriter.create<ClassOp>(funcOp.getLoc(), className);
54 
56  rewriter.createBlock(&newClassOp.getBody());
57  rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
58 
59  auto argAttrs = funcOp.getArgAttrs();
60  for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
61  StringAttr fieldName;
62  Attribute argAttr = nullptr;
63 
64  fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
65  if (argAttrs && idx < argAttrs->size())
66  argAttr = (*argAttrs)[idx];
67 
68  TypeAttr typeAttr = TypeAttr::get(val.getType());
69  fields.push_back({fieldName, typeAttr});
70  rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
71  argAttr);
72  }
73 
74  rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
75  FunctionType funcType = funcOp.getFunctionType();
76  Location loc = funcOp.getLoc();
77  FuncOp newFuncOp =
78  rewriter.create<emitc::FuncOp>(loc, ("execute"), funcType);
79 
80  rewriter.createBlock(&newFuncOp.getBody());
81  newFuncOp.getBody().takeBody(funcOp.getBody());
82 
83  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
84  std::vector<Value> newArguments;
85  newArguments.reserve(fields.size());
86  for (auto &[fieldName, attr] : fields) {
87  GetFieldOp arg =
88  rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
89  newArguments.push_back(arg);
90  }
91 
92  for (auto [oldArg, newArg] :
93  llvm::zip(newFuncOp.getArguments(), newArguments)) {
94  rewriter.replaceAllUsesWith(oldArg, newArg);
95  }
96 
97  llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
98  if (failed(newFuncOp.eraseArguments(argsToErase)))
99  newFuncOp->emitOpError("failed to erase all arguments using BitVector");
100 
101  rewriter.replaceOp(funcOp, newClassOp);
102  return success();
103  }
104 
105 private:
106  StringRef attributeName;
107 };
108 
110  StringRef namedAttribute) {
111  patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
112 }
static MLIRContext * getContext(OpFoldResult val)
WrapFuncInClass(MLIRContext *context, StringRef attrName)
LogicalResult matchAndRewrite(emitc::FuncOp funcOp, PatternRewriter &rewriter) const override
Attributes are known-constant values of operations.
Definition: Attributes.h:25
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:601
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateFuncPatterns(RewritePatternSet &patterns, StringRef namedAttribute)
Populates 'patterns' with func-related patterns.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314