MLIR  22.0.0git
WrapFuncInClass.cpp
Go to the documentation of this file.
1 //===- WrapFuncInClass.cpp - Wrap EmitC Funcs in classes -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace emitc;
20 
21 namespace mlir {
22 namespace emitc {
// Pull in the tablegen-generated pass base class
// (impl::WrapFuncInClassPassBase) for this pass.
23 #define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
24 #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
25 
26 namespace {
// Pass that applies this file's rewrite patterns to the root operation using
// the walk-based driver (walkAndApplyPatterns).
27 struct WrapFuncInClassPass
28  : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
// Reuse the constructors of the generated base class.
29  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
30  void runOnOperation() override {
31  Operation *rootOp = getOperation();
32 
// NOTE(review): listing lines 33-34, which declare and fill `patterns`, are
// missing from this view; presumably they call populateFuncPatterns — confirm.
35 
36  walkAndApplyPatterns(rootOp, std::move(patterns));
37  }
38 };
43 
44 class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
45 public:
47  : OpRewritePattern<emitc::FuncOp>(context) {}
48 
49  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
50  PatternRewriter &rewriter) const override {
51 
52  auto className = funcOp.getSymNameAttr().str() + "Class";
53  ClassOp newClassOp = ClassOp::create(rewriter, funcOp.getLoc(), className);
54 
56  rewriter.createBlock(&newClassOp.getBody());
57  rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
58 
59  auto argAttrs = funcOp.getArgAttrs();
60  for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
61  StringAttr fieldName =
62  rewriter.getStringAttr("fieldName" + std::to_string(idx));
63 
64  TypeAttr typeAttr = TypeAttr::get(val.getType());
65  fields.push_back({fieldName, typeAttr});
66 
67  FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(),
68  fieldName, typeAttr, nullptr);
69 
70  if (argAttrs && idx < argAttrs->size()) {
71  fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
72  }
73  }
74 
75  rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
76  FunctionType funcType = funcOp.getFunctionType();
77  Location loc = funcOp.getLoc();
78  FuncOp newFuncOp =
79  emitc::FuncOp::create(rewriter, loc, ("execute"), funcType);
80 
81  rewriter.createBlock(&newFuncOp.getBody());
82  newFuncOp.getBody().takeBody(funcOp.getBody());
83 
84  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
85  std::vector<Value> newArguments;
86  newArguments.reserve(fields.size());
87  for (auto &[fieldName, attr] : fields) {
88  GetFieldOp arg =
89  emitc::GetFieldOp::create(rewriter, loc, attr.getValue(), fieldName);
90  newArguments.push_back(arg);
91  }
92 
93  for (auto [oldArg, newArg] :
94  llvm::zip(newFuncOp.getArguments(), newArguments)) {
95  rewriter.replaceAllUsesWith(oldArg, newArg);
96  }
97 
98  llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
99  if (failed(newFuncOp.eraseArguments(argsToErase)))
100  newFuncOp->emitOpError("failed to erase all arguments using BitVector");
101 
102  rewriter.replaceOp(funcOp, newClassOp);
103  return success();
104  }
105 };
106 
// NOTE(review): the function header (listing line 107) is not shown here; the
// symbol index lists it as populateFuncPatterns(RewritePatternSet &patterns).
// Registers the WrapFuncInClass rewrite pattern in `patterns`.
108  patterns.add<WrapFuncInClass>(patterns.getContext());
109 }
static MLIRContext * getContext(OpFoldResult val)
LogicalResult matchAndRewrite(emitc::FuncOp funcOp, PatternRewriter &rewriter) const override
WrapFuncInClass(MLIRContext *context)
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:783
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateFuncPatterns(RewritePatternSet &patterns)
Populates 'patterns' with func-related patterns.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314