23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
27 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
36 void FuncDialect::initialize() {
39 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
41 declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
42 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
43 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
44 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
52 if (ConstantOp::isBuildableWith(value, type))
53 return ConstantOp::create(builder, loc, type,
54 llvm::cast<FlatSymbolRefAttr>(value));
66 return emitOpError(
"requires a 'callee' symbol reference attribute");
69 return emitOpError() <<
"'" << fnAttr.getValue()
70 <<
"' does not reference a valid function";
73 auto fnType = fn.getFunctionType();
74 if (fnType.getNumInputs() != getNumOperands())
75 return emitOpError(
"incorrect number of operands for callee");
77 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
78 if (getOperand(i).
getType() != fnType.getInput(i))
79 return emitOpError(
"operand type mismatch: expected operand type ")
80 << fnType.getInput(i) <<
", but provided "
81 << getOperand(i).getType() <<
" for operand number " << i;
83 if (fnType.getNumResults() != getNumResults())
84 return emitOpError(
"incorrect number of results for callee");
86 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
87 if (getResult(i).
getType() != fnType.getResult(i)) {
88 auto diag = emitOpError(
"result type mismatch at index ") << i;
89 diag.attachNote() <<
" op result types: " << getResultTypes();
90 diag.attachNote() <<
"function result types: " << fnType.getResults();
97 FunctionType CallOp::getCalleeType() {
106 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
109 SymbolRefAttr calledFn;
115 indirectCall.getResultTypes(),
116 indirectCall.getArgOperands());
125 StringRef fnName = getValue();
132 return emitOpError() <<
"reference to undefined function '" << fnName
136 if (fn.getFunctionType() != type)
137 return emitOpError(
"reference to function with mismatched type");
143 return getValueAttr();
146 void ConstantOp::getAsmResultNames(
148 setNameFn(getResult(),
"f");
151 bool ConstantOp::isBuildableWith(
Attribute value,
Type type) {
152 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
159 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
163 FuncOp::build(builder, state, name, type, attrs);
166 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
171 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
174 FuncOp func = create(location, name, type, attrs);
175 func.setAllArgAttrs(argAttrs);
184 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
185 state.attributes.append(attrs.begin(), attrs.end());
188 if (argAttrs.empty())
190 assert(type.getNumInputs() == argAttrs.size());
192 builder, state, argAttrs, {},
193 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
203 parser, result,
false,
204 getFunctionTypeAttrName(result.
name), buildFuncType,
205 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
210 p, *
this,
false, getFunctionTypeAttrName(),
211 getArgAttrsAttrName(), getResAttrsAttrName());
216 void FuncOp::cloneInto(FuncOp dest,
IRMapping &mapper) {
218 llvm::MapVector<StringAttr, Attribute> newAttrMap;
219 for (
const auto &attr : dest->getAttrs())
220 newAttrMap.insert({attr.getName(), attr.getValue()});
221 for (
const auto &attr : (*this)->getAttrs())
222 newAttrMap.insert({attr.getName(), attr.getValue()});
224 auto newAttrs = llvm::to_vector(llvm::map_range(
225 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
231 getBody().cloneInto(&dest.getBody(), mapper);
247 FunctionType oldType = getFunctionType();
249 unsigned oldNumArgs = oldType.getNumInputs();
251 newInputs.reserve(oldNumArgs);
252 for (
unsigned i = 0; i != oldNumArgs; ++i)
253 if (!mapper.
contains(getArgument(i)))
254 newInputs.push_back(oldType.getInput(i));
258 if (newInputs.size() != oldNumArgs) {
260 oldType.getResults()));
262 if (ArrayAttr argAttrs = getAllArgAttrs()) {
264 newArgAttrs.reserve(newInputs.size());
265 for (
unsigned i = 0; i != oldNumArgs; ++i)
266 if (!mapper.
contains(getArgument(i)))
267 newArgAttrs.push_back(argAttrs[i]);
268 newFunc.setAllArgAttrs(newArgAttrs);
274 cloneInto(newFunc, mapper);
279 return clone(mapper);
287 auto function = cast<FuncOp>((*this)->getParentOp());
290 const auto &results =
function.getFunctionType().getResults();
291 if (getNumOperands() != results.size())
292 return emitOpError(
"has ")
293 << getNumOperands() <<
" operands, but enclosing function (@"
294 <<
function.getName() <<
") returns " << results.size();
296 for (
unsigned i = 0, e = results.size(); i != e; ++i)
297 if (getOperand(i).
getType() != results[i])
298 return emitError() <<
"type of return operand " << i <<
" ("
299 << getOperand(i).getType()
300 <<
") doesn't match function result type ("
302 <<
" in function @" <<
function.getName();
311 #define GET_OP_CLASSES
312 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class is a general helper class for creating context-global objects like types,...
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
StringAttr getStringAttr(const Twine &bytes)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.