MLIR 22.0.0git
WrapFuncInClass.cpp
Go to the documentation of this file.
//===- WrapFuncInClass.cpp - Wrap EmitC funcs in classes ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Attributes.h"
13#include "mlir/IR/Builders.h"
17
18using namespace mlir;
19using namespace emitc;
20
21namespace mlir {
22namespace emitc {
23#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
24#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
25
26namespace {
27struct WrapFuncInClassPass
28 : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
29 using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
30 void runOnOperation() override {
31 Operation *rootOp = getOperation();
32
35
36 walkAndApplyPatterns(rootOp, std::move(patterns));
37 }
38};
39
40} // namespace
41} // namespace emitc
42} // namespace mlir
43
44class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
45public:
47 : OpRewritePattern<emitc::FuncOp>(context) {}
48
49 LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
50 PatternRewriter &rewriter) const override {
51
52 auto className = funcOp.getSymNameAttr().str() + "Class";
53 ClassOp newClassOp = ClassOp::create(rewriter, funcOp.getLoc(), className);
54
56 rewriter.createBlock(&newClassOp.getBody());
57 rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
58
59 auto argAttrs = funcOp.getArgAttrs();
60 for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
61 StringAttr fieldName =
62 rewriter.getStringAttr("fieldName" + std::to_string(idx));
63
64 TypeAttr typeAttr = TypeAttr::get(val.getType());
65 fields.push_back({fieldName, typeAttr});
66
67 FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(),
68 fieldName, typeAttr, nullptr);
69
70 if (argAttrs && idx < argAttrs->size()) {
71 fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
72 }
73 }
74
75 rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
76 FunctionType funcType = funcOp.getFunctionType();
77 Location loc = funcOp.getLoc();
78 FuncOp newFuncOp =
79 emitc::FuncOp::create(rewriter, loc, ("execute"), funcType);
80
81 rewriter.createBlock(&newFuncOp.getBody());
82 newFuncOp.getBody().takeBody(funcOp.getBody());
83
84 rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
85 std::vector<Value> newArguments;
86 newArguments.reserve(fields.size());
87 for (auto &[fieldName, attr] : fields) {
88 GetFieldOp arg =
89 emitc::GetFieldOp::create(rewriter, loc, attr.getValue(), fieldName);
90 newArguments.push_back(arg);
91 }
92
93 for (auto [oldArg, newArg] :
94 llvm::zip(newFuncOp.getArguments(), newArguments)) {
95 rewriter.replaceAllUsesWith(oldArg, newArg);
96 }
97
98 llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
99 if (failed(newFuncOp.eraseArguments(argsToErase)))
100 newFuncOp->emitOpError("failed to erase all arguments using BitVector");
102 rewriter.replaceOp(funcOp, newClassOp);
103 return success();
104 }
return success()
b getContext())
LogicalResult matchAndRewrite(emitc::FuncOp funcOp, PatternRewriter &rewriter) const override
WrapFuncInClass(MLIRContext *context)
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void populateFuncPatterns(RewritePatternSet &patterns)
Populates 'patterns' with func-related patterns.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})