MLIR 23.0.0git
FuncOps.cpp
Go to the documentation of this file.
1//===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/IR/IRMapping.h"
16#include "mlir/IR/Matchers.h"
20#include "mlir/IR/Value.h"
23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/MapVector.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallVectorExtras.h"
27
28#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
29
30using namespace mlir;
31using namespace mlir::func;
32
33//===----------------------------------------------------------------------===//
34// FuncDialect
35//===----------------------------------------------------------------------===//
36
37void FuncDialect::initialize() {
38 addOperations<
39#define GET_OP_LIST
40#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
41 >();
42 declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
43 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
44 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
45 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
46 FuncOp, ReturnOp>();
47}
48
49/// Materialize a single constant operation from a given attribute value with
50/// the desired resultant type.
51Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value,
52 Type type, Location loc) {
53 if (ConstantOp::isBuildableWith(value, type))
54 return ConstantOp::create(builder, loc, type,
55 llvm::cast<FlatSymbolRefAttr>(value));
56 return nullptr;
57}
58
59//===----------------------------------------------------------------------===//
60// CallOp
61//===----------------------------------------------------------------------===//
62
63LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
64 // Check that the callee attribute was specified.
65 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
66 if (!fnAttr)
67 return emitOpError("requires a 'callee' symbol reference attribute");
68 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
69 if (!fn)
70 return emitOpError() << "'" << fnAttr.getValue()
71 << "' does not reference a valid function";
72
73 // Verify that the operand and result types match the callee.
74 auto fnType = fn.getFunctionType();
75 if (fnType.getNumInputs() != getNumOperands())
76 return emitOpError("incorrect number of operands for callee");
77
78 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
79 if (getOperand(i).getType() != fnType.getInput(i))
80 return emitOpError("operand type mismatch: expected operand type ")
81 << fnType.getInput(i) << ", but provided "
82 << getOperand(i).getType() << " for operand number " << i;
83
84 if (fnType.getNumResults() != getNumResults())
85 return emitOpError("incorrect number of results for callee");
86
87 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
88 if (getResult(i).getType() != fnType.getResult(i)) {
89 auto diag = emitOpError("result type mismatch at index ") << i;
90 diag.attachNote() << " op result types: " << getResultTypes();
91 diag.attachNote() << "function result types: " << fnType.getResults();
92 return diag;
93 }
94
95 return success();
96}
97
98FunctionType CallOp::getCalleeType() {
99 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
100}
101
102//===----------------------------------------------------------------------===//
103// CallIndirectOp
104//===----------------------------------------------------------------------===//
105
106/// Fold indirect calls that have a constant function as the callee operand.
107LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
108 PatternRewriter &rewriter) {
109 // Check that the callee is a constant callee.
110 SymbolRefAttr calledFn;
111 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
112 return failure();
113
114 // Replace with a direct call.
115 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
116 indirectCall.getResultTypes(),
117 indirectCall.getArgOperands());
118 return success();
119}
120
121//===----------------------------------------------------------------------===//
122// ConstantOp
123//===----------------------------------------------------------------------===//
124
125LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
126 StringRef fnName = getValue();
127 Type type = getType();
128
129 // Try to find the referenced function.
130 auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
131 this->getOperation(), StringAttr::get(getContext(), fnName));
132 if (!fn)
133 return emitOpError() << "reference to undefined function '" << fnName
134 << "'";
135
136 // Check that the referenced function has the correct type.
137 if (fn.getFunctionType() != type)
138 return emitOpError("reference to function with mismatched type");
139
140 return success();
141}
142
143OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
144 return getValueAttr();
145}
146
147void ConstantOp::getAsmResultNames(
148 function_ref<void(Value, StringRef)> setNameFn) {
149 setNameFn(getResult(), "f");
150}
151
152bool ConstantOp::isBuildableWith(Attribute value, Type type) {
153 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
154}
155
156//===----------------------------------------------------------------------===//
157// FuncOp
158//===----------------------------------------------------------------------===//
159
160FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
162 OpBuilder builder(location->getContext());
163 OperationState state(location, getOperationName());
164 FuncOp::build(builder, state, name, type, attrs);
165 return cast<FuncOp>(Operation::create(state));
166}
167FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
169 SmallVector<NamedAttribute, 8> attrRef(attrs);
170 return create(location, name, type, llvm::ArrayRef(attrRef));
171}
172FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
174 ArrayRef<DictionaryAttr> argAttrs) {
175 FuncOp func = create(location, name, type, attrs);
176 func.setAllArgAttrs(argAttrs);
177 return func;
178}
179
180void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
181 FunctionType type, ArrayRef<NamedAttribute> attrs,
182 ArrayRef<DictionaryAttr> argAttrs) {
184 builder.getStringAttr(name));
185 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
186 state.attributes.append(attrs.begin(), attrs.end());
187 state.addRegion();
188
189 if (argAttrs.empty())
190 return;
191 assert(type.getNumInputs() == argAttrs.size());
193 builder, state, argAttrs, /*resultAttrs=*/{},
194 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
195}
196
197ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
198 auto buildFuncType =
199 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
201 std::string &) { return builder.getFunctionType(argTypes, results); };
202
204 parser, result, /*allowVariadic=*/false,
205 getFunctionTypeAttrName(result.name), buildFuncType,
206 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
207}
208
209void FuncOp::print(OpAsmPrinter &p) {
211 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
212 getArgAttrsAttrName(), getResAttrsAttrName());
213}
214
215/// Clone the internal blocks from this function into dest and all attributes
216/// from this function to dest.
217void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
218 // Add the attributes of this function to dest.
219 llvm::MapVector<StringAttr, Attribute> newAttrMap;
220 for (const auto &attr : dest->getAttrs())
221 newAttrMap.insert({attr.getName(), attr.getValue()});
222 for (const auto &attr : (*this)->getAttrs())
223 newAttrMap.insert({attr.getName(), attr.getValue()});
224
225 auto newAttrs = llvm::map_to_vector(
226 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
227 return NamedAttribute(attrPair.first, attrPair.second);
228 });
229 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
230
231 // Clone the body.
232 getBody().cloneInto(&dest.getBody(), mapper);
233}
234
235/// Create a deep copy of this function and all of its blocks, remapping
236/// any operands that use values outside of the function using the map that is
237/// provided (leaving them alone if no entry is present). Replaces references
238/// to cloned sub-values with the corresponding value that is copied, and adds
239/// those mappings to the mapper.
240FuncOp FuncOp::clone(IRMapping &mapper) {
241 // Create the new function.
242 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
243
244 // If the function has a body, then the user might be deleting arguments to
245 // the function by specifying them in the mapper. If so, we don't add the
246 // argument to the input type vector.
247 if (!isExternal()) {
248 FunctionType oldType = getFunctionType();
249
250 unsigned oldNumArgs = oldType.getNumInputs();
251 SmallVector<Type, 4> newInputs;
252 newInputs.reserve(oldNumArgs);
253 for (unsigned i = 0; i != oldNumArgs; ++i)
254 if (!mapper.contains(getArgument(i)))
255 newInputs.push_back(oldType.getInput(i));
256
257 /// If any of the arguments were dropped, update the type and drop any
258 /// necessary argument attributes.
259 if (newInputs.size() != oldNumArgs) {
260 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
261 oldType.getResults()));
262
263 if (ArrayAttr argAttrs = getAllArgAttrs()) {
264 SmallVector<Attribute> newArgAttrs;
265 newArgAttrs.reserve(newInputs.size());
266 for (unsigned i = 0; i != oldNumArgs; ++i)
267 if (!mapper.contains(getArgument(i)))
268 newArgAttrs.push_back(argAttrs[i]);
269 newFunc.setAllArgAttrs(newArgAttrs);
270 }
271 }
272 }
273
274 /// Clone the current function into the new one and return it.
275 cloneInto(newFunc, mapper);
276 return newFunc;
277}
278FuncOp FuncOp::clone() {
279 IRMapping mapper;
280 return clone(mapper);
281}
282
283//===----------------------------------------------------------------------===//
284// ReturnOp
285//===----------------------------------------------------------------------===//
286
287LogicalResult ReturnOp::verify() {
288 auto function = cast<FuncOp>((*this)->getParentOp());
289
290 // The operand number and types must match the function signature.
291 const auto &results = function.getFunctionType().getResults();
292 if (getNumOperands() != results.size())
293 return emitOpError("has ")
294 << getNumOperands() << " operands, but enclosing function (@"
295 << function.getName() << ") returns " << results.size();
296
297 for (unsigned i = 0, e = results.size(); i != e; ++i)
298 if (getOperand(i).getType() != results[i])
299 return emitError() << "type of return operand " << i << " ("
300 << getOperand(i).getType()
301 << ") doesn't match function result type ("
302 << results[i] << ")"
303 << " in function @" << function.getName();
304
305 return success();
306}
307
308//===----------------------------------------------------------------------===//
309// TableGen'd op method definitions
310//===----------------------------------------------------------------------===//
311
312#define GET_OP_CLASSES
313#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition IRMapping.h:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
iterator_range< dialect_attr_iterator > dialect_attr_range
Definition Operation.h:634
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.