23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "llvm/Support/raw_ostream.h"
30 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
39 void FuncDialect::initialize() {
42 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
44 declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
45 declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
46 declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
54 if (ConstantOp::isBuildableWith(value, type))
55 return builder.
create<ConstantOp>(loc, type,
56 llvm::cast<FlatSymbolRefAttr>(value));
68 return emitOpError(
"requires a 'callee' symbol reference attribute");
71 return emitOpError() <<
"'" << fnAttr.getValue()
72 <<
"' does not reference a valid function";
75 auto fnType = fn.getFunctionType();
76 if (fnType.getNumInputs() != getNumOperands())
77 return emitOpError(
"incorrect number of operands for callee");
79 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
80 if (getOperand(i).
getType() != fnType.getInput(i))
81 return emitOpError(
"operand type mismatch: expected operand type ")
82 << fnType.getInput(i) <<
", but provided "
83 << getOperand(i).getType() <<
" for operand number " << i;
85 if (fnType.getNumResults() != getNumResults())
86 return emitOpError(
"incorrect number of results for callee");
88 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
89 if (getResult(i).
getType() != fnType.getResult(i)) {
90 auto diag = emitOpError(
"result type mismatch at index ") << i;
91 diag.attachNote() <<
" op result types: " << getResultTypes();
92 diag.attachNote() <<
"function result types: " << fnType.getResults();
99 FunctionType CallOp::getCalleeType() {
108 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
111 SymbolRefAttr calledFn;
117 indirectCall.getResultTypes(),
118 indirectCall.getArgOperands());
127 StringRef fnName = getValue();
134 return emitOpError() <<
"reference to undefined function '" << fnName
138 if (fn.getFunctionType() != type)
139 return emitOpError(
"reference to function with mismatched type");
145 return getValueAttr();
148 void ConstantOp::getAsmResultNames(
150 setNameFn(getResult(),
"f");
153 bool ConstantOp::isBuildableWith(
Attribute value,
Type type) {
154 return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
161 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
165 FuncOp::build(builder, state, name, type, attrs);
168 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
173 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
176 FuncOp func = create(location, name, type, attrs);
177 func.setAllArgAttrs(argAttrs);
186 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
187 state.attributes.append(attrs.begin(), attrs.end());
190 if (argAttrs.empty())
192 assert(type.getNumInputs() == argAttrs.size());
194 builder, state, argAttrs, std::nullopt,
195 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
205 parser, result,
false,
206 getFunctionTypeAttrName(result.
name), buildFuncType,
207 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
212 p, *
this,
false, getFunctionTypeAttrName(),
213 getArgAttrsAttrName(), getResAttrsAttrName());
218 void FuncOp::cloneInto(FuncOp dest,
IRMapping &mapper) {
220 llvm::MapVector<StringAttr, Attribute> newAttrMap;
221 for (
const auto &attr : dest->getAttrs())
222 newAttrMap.insert({attr.getName(), attr.getValue()});
223 for (
const auto &attr : (*this)->getAttrs())
224 newAttrMap.insert({attr.getName(), attr.getValue()});
226 auto newAttrs = llvm::to_vector(llvm::map_range(
227 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
233 getBody().cloneInto(&dest.getBody(), mapper);
249 FunctionType oldType = getFunctionType();
251 unsigned oldNumArgs = oldType.getNumInputs();
253 newInputs.reserve(oldNumArgs);
254 for (
unsigned i = 0; i != oldNumArgs; ++i)
255 if (!mapper.
contains(getArgument(i)))
256 newInputs.push_back(oldType.getInput(i));
260 if (newInputs.size() != oldNumArgs) {
262 oldType.getResults()));
264 if (ArrayAttr argAttrs = getAllArgAttrs()) {
266 newArgAttrs.reserve(newInputs.size());
267 for (
unsigned i = 0; i != oldNumArgs; ++i)
268 if (!mapper.
contains(getArgument(i)))
269 newArgAttrs.push_back(argAttrs[i]);
270 newFunc.setAllArgAttrs(newArgAttrs);
276 cloneInto(newFunc, mapper);
281 return clone(mapper);
289 auto function = cast<FuncOp>((*this)->getParentOp());
292 const auto &results =
function.getFunctionType().getResults();
293 if (getNumOperands() != results.size())
294 return emitOpError(
"has ")
295 << getNumOperands() <<
" operands, but enclosing function (@"
296 <<
function.getName() <<
") returns " << results.size();
298 for (
unsigned i = 0, e = results.size(); i != e; ++i)
299 if (getOperand(i).
getType() != results[i])
300 return emitError() <<
"type of return operand " << i <<
" ("
301 << getOperand(i).getType()
302 <<
") doesn't match function result type ("
304 <<
" in function @" <<
function.getName();
313 #define GET_OP_CLASSES
314 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class is a general helper class for creating context-global objects like types,...
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
StringAttr getStringAttr(const Twine &bytes)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.