23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/MapVector.h"
25#include "llvm/ADT/STLExtras.h"
27#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
36void FuncDialect::initialize() {
39#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
41 declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
42 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
43 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
44 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
52 if (ConstantOp::isBuildableWith(value, type))
53 return ConstantOp::create(builder, loc, type,
54 llvm::cast<FlatSymbolRefAttr>(value));
66 return emitOpError(
"requires a 'callee' symbol reference attribute");
70 <<
"' does not reference a valid function";
73 auto fnType = fn.getFunctionType();
74 if (fnType.getNumInputs() != getNumOperands())
75 return emitOpError(
"incorrect number of operands for callee");
77 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
78 if (getOperand(i).
getType() != fnType.getInput(i))
79 return emitOpError(
"operand type mismatch: expected operand type ")
80 << fnType.getInput(i) <<
", but provided "
81 << getOperand(i).getType() <<
" for operand number " << i;
83 if (fnType.getNumResults() != getNumResults())
84 return emitOpError(
"incorrect number of results for callee");
86 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
87 if (getResult(i).
getType() != fnType.getResult(i)) {
89 diag.attachNote() <<
" op result types: " << getResultTypes();
90 diag.attachNote() <<
"function result types: " << fnType.getResults();
97FunctionType CallOp::getCalleeType() {
98 return FunctionType::get(
getContext(), getOperandTypes(), getResultTypes());
106LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
109 SymbolRefAttr calledFn;
115 indirectCall.getResultTypes(),
116 indirectCall.getArgOperands());
125 StringRef fnName = getValue();
130 this->getOperation(), StringAttr::get(
getContext(), fnName));
132 return emitOpError() <<
"reference to undefined function '" << fnName
136 if (fn.getFunctionType() != type)
137 return emitOpError(
"reference to function with mismatched type");
143 return getValueAttr();
146void ConstantOp::getAsmResultNames(
148 setNameFn(getResult(),
"f");
151bool ConstantOp::isBuildableWith(
Attribute value,
Type type) {
152 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
159FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
163 FuncOp::build(builder, state, name, type, attrs);
166FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
171FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
174 FuncOp
func = create(location, name, type, attrs);
175 func.setAllArgAttrs(argAttrs);
184 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
188 if (argAttrs.empty())
190 assert(type.getNumInputs() == argAttrs.size());
192 builder, state, argAttrs, {},
193 getArgAttrsAttrName(state.
name), getResAttrsAttrName(state.
name));
204 getFunctionTypeAttrName(
result.name), buildFuncType,
205 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
210 p, *
this,
false, getFunctionTypeAttrName(),
211 getArgAttrsAttrName(), getResAttrsAttrName());
216void FuncOp::cloneInto(FuncOp dest,
IRMapping &mapper) {
218 llvm::MapVector<StringAttr, Attribute> newAttrMap;
219 for (
const auto &attr : dest->getAttrs())
220 newAttrMap.insert({attr.getName(), attr.getValue()});
221 for (
const auto &attr : (*this)->getAttrs())
222 newAttrMap.insert({attr.getName(), attr.getValue()});
224 auto newAttrs = llvm::to_vector(llvm::map_range(
225 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
228 dest->setAttrs(DictionaryAttr::get(
getContext(), newAttrs));
231 getBody().cloneInto(&dest.getBody(), mapper);
247 FunctionType oldType = getFunctionType();
249 unsigned oldNumArgs = oldType.getNumInputs();
251 newInputs.reserve(oldNumArgs);
252 for (
unsigned i = 0; i != oldNumArgs; ++i)
253 if (!mapper.
contains(getArgument(i)))
254 newInputs.push_back(oldType.getInput(i));
258 if (newInputs.size() != oldNumArgs) {
259 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
260 oldType.getResults()));
262 if (
ArrayAttr argAttrs = getAllArgAttrs()) {
264 newArgAttrs.reserve(newInputs.size());
265 for (
unsigned i = 0; i != oldNumArgs; ++i)
266 if (!mapper.
contains(getArgument(i)))
267 newArgAttrs.push_back(argAttrs[i]);
268 newFunc.setAllArgAttrs(newArgAttrs);
274 cloneInto(newFunc, mapper);
277FuncOp FuncOp::clone() {
279 return clone(mapper);
286LogicalResult ReturnOp::verify() {
287 auto function = cast<FuncOp>((*this)->getParentOp());
290 const auto &results = function.getFunctionType().getResults();
291 if (getNumOperands() != results.size())
293 << getNumOperands() <<
" operands, but enclosing function (@"
294 << function.getName() <<
") returns " << results.size();
296 for (
unsigned i = 0, e = results.size(); i != e; ++i)
297 if (getOperand(i).
getType() != results[i])
298 return emitError() <<
"type of return operand " << i <<
" ("
299 << getOperand(i).getType()
300 <<
") doesn't match function result type ("
302 <<
" in function @" << function.getName();
311#define GET_OP_CLASSES
312#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class is a general helper class for creating context-global objects like types,...
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
StringAttr getStringAttr(const Twine &bytes)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
iterator_range< dialect_attr_iterator > dialect_attr_range
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.