MLIR 22.0.0git
FuncOps.cpp
Go to the documentation of this file.
1//===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Matchers.h"
20#include "mlir/IR/Value.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/MapVector.h"
25#include "llvm/ADT/STLExtras.h"
26
27#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
28
29using namespace mlir;
30using namespace mlir::func;
31
32//===----------------------------------------------------------------------===//
33// FuncDialect
34//===----------------------------------------------------------------------===//
35
/// Dialect initialization hook: registers the dialect's operations and
/// declares the interfaces it promises.
36void FuncDialect::initialize() {
// The template argument list of addOperations<> is supplied by the generated
// FuncOps.cpp.inc file, included here with GET_OP_LIST defined.
37 addOperations<
38#define GET_OP_LIST
39#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
40 >();
// Dialect-level promised interfaces (EmitC conversion patterns, inliner, LLVM
// conversion patterns).
41 declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
42 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
43 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
// BufferizableOpInterface is promised for CallOp, FuncOp and ReturnOp.
44 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
45 FuncOp, ReturnOp>();
46}
47
48/// Materialize a single constant operation from a given attribute value with
49/// the desired resultant type.
50Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
51 Type type, Location loc) {
52 if (ConstantOp::isBuildableWith(value, type))
53 return ConstantOp::create(builder, loc, type,
54 llvm::cast<FlatSymbolRefAttr>(value));
55 return nullptr;
56}
57
58//===----------------------------------------------------------------------===//
59// CallOp
60//===----------------------------------------------------------------------===//
61
62LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
63 // Check that the callee attribute was specified.
64 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
65 if (!fnAttr)
66 return emitOpError("requires a 'callee' symbol reference attribute");
67 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
68 if (!fn)
69 return emitOpError() << "'" << fnAttr.getValue()
70 << "' does not reference a valid function";
71
72 // Verify that the operand and result types match the callee.
73 auto fnType = fn.getFunctionType();
74 if (fnType.getNumInputs() != getNumOperands())
75 return emitOpError("incorrect number of operands for callee");
76
77 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
78 if (getOperand(i).getType() != fnType.getInput(i))
79 return emitOpError("operand type mismatch: expected operand type ")
80 << fnType.getInput(i) << ", but provided "
81 << getOperand(i).getType() << " for operand number " << i;
82
83 if (fnType.getNumResults() != getNumResults())
84 return emitOpError("incorrect number of results for callee");
85
86 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
87 if (getResult(i).getType() != fnType.getResult(i)) {
88 auto diag = emitOpError("result type mismatch at index ") << i;
89 diag.attachNote() << " op result types: " << getResultTypes();
90 diag.attachNote() << "function result types: " << fnType.getResults();
91 return diag;
92 }
93
94 return success();
95}
96
97FunctionType CallOp::getCalleeType() {
98 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
99}
100
101//===----------------------------------------------------------------------===//
102// CallIndirectOp
103//===----------------------------------------------------------------------===//
104
105/// Fold indirect calls that have a constant function as the callee operand.
106LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
107 PatternRewriter &rewriter) {
108 // Check that the callee is a constant callee.
109 SymbolRefAttr calledFn;
110 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
111 return failure();
112
113 // Replace with a direct call.
114 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
115 indirectCall.getResultTypes(),
116 indirectCall.getArgOperands());
117 return success();
118}
119
120//===----------------------------------------------------------------------===//
121// ConstantOp
122//===----------------------------------------------------------------------===//
123
124LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
125 StringRef fnName = getValue();
126 Type type = getType();
127
128 // Try to find the referenced function.
129 auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
130 this->getOperation(), StringAttr::get(getContext(), fnName));
131 if (!fn)
132 return emitOpError() << "reference to undefined function '" << fnName
133 << "'";
134
135 // Check that the referenced function has the correct type.
136 if (fn.getFunctionType() != type)
137 return emitOpError("reference to function with mismatched type");
138
139 return success();
140}
141
/// Fold hook: returns the attribute obtained from getValueAttr(). The
/// `adaptor` argument is not used.
142OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
143 return getValueAttr();
144}
145
/// Invokes `setNameFn` once, passing this op's result and the name "f".
146void ConstantOp::getAsmResultNames(
147 function_ref<void(Value, StringRef)> setNameFn) {
148 setNameFn(getResult(), "f");
149}
150
151bool ConstantOp::isBuildableWith(Attribute value, Type type) {
152 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
153}
154
155//===----------------------------------------------------------------------===//
156// FuncOp
157//===----------------------------------------------------------------------===//
158
/// Creates a FuncOp named `name` with function type `type` at `location`.
/// Populates an OperationState through FuncOp::build, using an OpBuilder
/// constructed from the location's context, and then creates the operation
/// with Operation::create.
// NOTE(review): the embedded listing numbers jump from 159 to 161, so the line
// that closes this parameter list (it has to declare `attrs`, which the body
// passes to build) appears to be missing from this copy. Restore it before
// compiling.
159FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
161 OpBuilder builder(location->getContext());
162 OperationState state(location, getOperationName());
163 FuncOp::build(builder, state, name, type, attrs);
164 return cast<FuncOp>(Operation::create(state));
165}
/// Overload that copies `attrs` into a SmallVector<NamedAttribute, 8> and
/// forwards to the create overload taking an ArrayRef of attributes.
// NOTE(review): the embedded listing numbers jump from 166 to 168, so the line
// declaring the `attrs` parameter appears to be missing from this copy.
// Restore it before compiling.
166FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
168 SmallVector<NamedAttribute, 8> attrRef(attrs);
169 return create(location, name, type, llvm::ArrayRef(attrRef));
170}
/// Overload that creates the function via create(location, name, type, attrs)
/// and then sets all of its argument attributes from `argAttrs`.
// NOTE(review): the embedded listing numbers jump from 171 to 173, so the line
// declaring the `attrs` parameter appears to be missing from this copy.
// Restore it before compiling.
171FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
173 ArrayRef<DictionaryAttr> argAttrs) {
174 FuncOp func = create(location, name, type, attrs);
175 func.setAllArgAttrs(argAttrs);
176 return func;
177}
178
/// Populates `state` for a FuncOp: records the function type attribute,
/// appends the caller-provided `attrs`, and adds one region. When `argAttrs`
/// is non-empty it must have one entry per function input (asserted), and is
/// then forwarded together with empty result attributes.
// NOTE(review): the embedded listing numbers skip 182 and 191. The openings of
// the two calls whose argument lists continue on lines 183 and 192-193 appear
// to be missing from this copy. Restore them before compiling.
179void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
180 FunctionType type, ArrayRef<NamedAttribute> attrs,
181 ArrayRef<DictionaryAttr> argAttrs) {
183 builder.getStringAttr(name));
184 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
185 state.attributes.append(attrs.begin(), attrs.end());
186 state.addRegion();
187
// Without per-argument attributes there is nothing more to record.
188 if (argAttrs.empty())
189 return;
190 assert(type.getNumInputs() == argAttrs.size());
192 builder, state, argAttrs, /*resultAttrs=*/{},
193 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
194}
195
/// Custom assembly parser for FuncOp. Defines a callback that builds the
/// function type from the parsed argument and result types, and passes it,
/// with allowVariadic=false and the function-type / arg-attrs / res-attrs
/// attribute names, to a parsing helper.
// NOTE(review): the embedded listing numbers skip 199 and 202. One lambda
// parameter line and the opening of the call whose arguments continue on lines
// 203-205 appear to be missing from this copy. Restore them before compiling.
196ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
// The callback ignores its trailing std::string & parameter.
197 auto buildFuncType =
198 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
200 std::string &) { return builder.getFunctionType(argTypes, results); };
201
203 parser, result, /*allowVariadic=*/false,
204 getFunctionTypeAttrName(result.name), buildFuncType,
205 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
206}
207
/// Custom assembly printer for FuncOp. Forwards the printer, this op,
/// isVariadic=false, and the function-type / arg-attrs / res-attrs attribute
/// names to a printing helper.
// NOTE(review): the embedded listing numbers jump from 208 to 210, so the
// opening of the call whose arguments continue on lines 210-211 appears to be
// missing from this copy. Restore it before compiling.
208void FuncOp::print(OpAsmPrinter &p) {
210 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
211 getArgAttrsAttrName(), getResAttrsAttrName());
212}
213
214/// Clone the internal blocks from this function into dest and all attributes
215/// from this function to dest.
216void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
217 // Add the attributes of this function to dest.
218 llvm::MapVector<StringAttr, Attribute> newAttrMap;
219 for (const auto &attr : dest->getAttrs())
220 newAttrMap.insert({attr.getName(), attr.getValue()});
221 for (const auto &attr : (*this)->getAttrs())
222 newAttrMap.insert({attr.getName(), attr.getValue()});
223
224 auto newAttrs = llvm::to_vector(llvm::map_range(
225 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
226 return NamedAttribute(attrPair.first, attrPair.second);
227 }));
228 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
229
230 // Clone the body.
231 getBody().cloneInto(&dest.getBody(), mapper);
232}
233
234/// Create a deep copy of this function and all of its blocks, remapping
235/// any operands that use values outside of the function using the map that is
236/// provided (leaving them alone if no entry is present). Replaces references
237/// to cloned sub-values with the corresponding value that is copied, and adds
238/// those mappings to the mapper.
239FuncOp FuncOp::clone(IRMapping &mapper) {
240 // Create the new function.
241 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
242
243 // If the function has a body, then the user might be deleting arguments to
244 // the function by specifying them in the mapper. If so, we don't add the
245 // argument to the input type vector.
246 if (!isExternal()) {
247 FunctionType oldType = getFunctionType();
248
249 unsigned oldNumArgs = oldType.getNumInputs();
250 SmallVector<Type, 4> newInputs;
251 newInputs.reserve(oldNumArgs);
252 for (unsigned i = 0; i != oldNumArgs; ++i)
253 if (!mapper.contains(getArgument(i)))
254 newInputs.push_back(oldType.getInput(i));
255
256 /// If any of the arguments were dropped, update the type and drop any
257 /// necessary argument attributes.
258 if (newInputs.size() != oldNumArgs) {
259 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
260 oldType.getResults()));
261
262 if (ArrayAttr argAttrs = getAllArgAttrs()) {
263 SmallVector<Attribute> newArgAttrs;
264 newArgAttrs.reserve(newInputs.size());
265 for (unsigned i = 0; i != oldNumArgs; ++i)
266 if (!mapper.contains(getArgument(i)))
267 newArgAttrs.push_back(argAttrs[i]);
268 newFunc.setAllArgAttrs(newArgAttrs);
269 }
270 }
271 }
272
273 /// Clone the current function into the new one and return it.
274 cloneInto(newFunc, mapper);
275 return newFunc;
276}
/// Convenience overload: clones this function using a fresh, empty IRMapping.
277FuncOp FuncOp::clone() {
278 IRMapping mapper;
279 return clone(mapper);
280}
281
282//===----------------------------------------------------------------------===//
283// ReturnOp
284//===----------------------------------------------------------------------===//
285
286LogicalResult ReturnOp::verify() {
287 auto function = cast<FuncOp>((*this)->getParentOp());
288
289 // The operand number and types must match the function signature.
290 const auto &results = function.getFunctionType().getResults();
291 if (getNumOperands() != results.size())
292 return emitOpError("has ")
293 << getNumOperands() << " operands, but enclosing function (@"
294 << function.getName() << ") returns " << results.size();
295
296 for (unsigned i = 0, e = results.size(); i != e; ++i)
297 if (getOperand(i).getType() != results[i])
298 return emitError() << "type of return operand " << i << " ("
299 << getOperand(i).getType()
300 << ") doesn't match function result type ("
301 << results[i] << ")"
302 << " in function @" << function.getName();
303
304 return success();
305}
306
307//===----------------------------------------------------------------------===//
308// TableGen'd op method definitions
309//===----------------------------------------------------------------------===//
310
311#define GET_OP_CLASSES
312#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition Builders.h:207
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
iterator_range< dialect_attr_iterator > dialect_attr_range
Definition Operation.h:634
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.