23#include "llvm/ADT/APFloat.h"
24#include "llvm/ADT/MapVector.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallVectorExtras.h"
28#include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
37void FuncDialect::initialize() {
40#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
42 declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
43 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
44 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
45 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
53 if (ConstantOp::isBuildableWith(value, type))
54 return ConstantOp::create(builder, loc, type,
55 llvm::cast<FlatSymbolRefAttr>(value));
67 return emitOpError(
"requires a 'callee' symbol reference attribute");
71 <<
"' does not reference a valid function";
74 auto fnType = fn.getFunctionType();
75 if (fnType.getNumInputs() != getNumOperands())
76 return emitOpError(
"incorrect number of operands for callee");
78 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
79 if (getOperand(i).
getType() != fnType.getInput(i))
80 return emitOpError(
"operand type mismatch: expected operand type ")
81 << fnType.getInput(i) <<
", but provided "
82 << getOperand(i).getType() <<
" for operand number " << i;
84 if (fnType.getNumResults() != getNumResults())
85 return emitOpError(
"incorrect number of results for callee");
87 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
88 if (getResult(i).
getType() != fnType.getResult(i)) {
90 diag.attachNote() <<
" op result types: " << getResultTypes();
91 diag.attachNote() <<
"function result types: " << fnType.getResults();
98FunctionType CallOp::getCalleeType() {
99 return FunctionType::get(
getContext(), getOperandTypes(), getResultTypes());
107LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
110 SymbolRefAttr calledFn;
116 indirectCall.getResultTypes(),
117 indirectCall.getArgOperands());
126 StringRef fnName = getValue();
131 this->getOperation(), StringAttr::get(
getContext(), fnName));
133 return emitOpError() <<
"reference to undefined function '" << fnName
137 if (fn.getFunctionType() != type)
138 return emitOpError(
"reference to function with mismatched type");
144 return getValueAttr();
147void ConstantOp::getAsmResultNames(
149 setNameFn(getResult(),
"f");
152bool ConstantOp::isBuildableWith(
Attribute value,
Type type) {
153 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
160FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
164 FuncOp::build(builder, state, name, type, attrs);
167FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
172FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
175 FuncOp
func = create(location, name, type, attrs);
176 func.setAllArgAttrs(argAttrs);
185 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
189 if (argAttrs.empty())
191 assert(type.getNumInputs() == argAttrs.size());
193 builder, state, argAttrs, {},
194 getArgAttrsAttrName(state.
name), getResAttrsAttrName(state.
name));
205 getFunctionTypeAttrName(
result.name), buildFuncType,
206 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
211 p, *
this,
false, getFunctionTypeAttrName(),
212 getArgAttrsAttrName(), getResAttrsAttrName());
217void FuncOp::cloneInto(FuncOp dest,
IRMapping &mapper) {
219 llvm::MapVector<StringAttr, Attribute> newAttrMap;
220 for (
const auto &attr : dest->getAttrs())
221 newAttrMap.insert({attr.getName(), attr.getValue()});
222 for (
const auto &attr : (*this)->getAttrs())
223 newAttrMap.insert({attr.getName(), attr.getValue()});
225 auto newAttrs = llvm::map_to_vector(
226 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
229 dest->setAttrs(DictionaryAttr::get(
getContext(), newAttrs));
232 getBody().cloneInto(&dest.getBody(), mapper);
248 FunctionType oldType = getFunctionType();
250 unsigned oldNumArgs = oldType.getNumInputs();
252 newInputs.reserve(oldNumArgs);
253 for (
unsigned i = 0; i != oldNumArgs; ++i)
254 if (!mapper.
contains(getArgument(i)))
255 newInputs.push_back(oldType.getInput(i));
259 if (newInputs.size() != oldNumArgs) {
260 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
261 oldType.getResults()));
263 if (
ArrayAttr argAttrs = getAllArgAttrs()) {
265 newArgAttrs.reserve(newInputs.size());
266 for (
unsigned i = 0; i != oldNumArgs; ++i)
267 if (!mapper.
contains(getArgument(i)))
268 newArgAttrs.push_back(argAttrs[i]);
269 newFunc.setAllArgAttrs(newArgAttrs);
275 cloneInto(newFunc, mapper);
278FuncOp FuncOp::clone() {
280 return clone(mapper);
287LogicalResult ReturnOp::verify() {
288 auto function = cast<FuncOp>((*this)->getParentOp());
291 const auto &results = function.getFunctionType().getResults();
292 if (getNumOperands() != results.size())
294 << getNumOperands() <<
" operands, but enclosing function (@"
295 << function.getName() <<
") returns " << results.size();
297 for (
unsigned i = 0, e = results.size(); i != e; ++i)
298 if (getOperand(i).
getType() != results[i])
299 return emitError() <<
"type of return operand " << i <<
" ("
300 << getOperand(i).getType()
301 <<
") doesn't match function result type ("
303 <<
" in function @" << function.getName();
312#define GET_OP_CLASSES
313#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class is a general helper class for creating context-global objects like types,...
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
StringAttr getStringAttr(const Twine &bytes)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
iterator_range< dialect_attr_iterator > dialect_attr_range
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.