MLIR  22.0.0git
FuncOps.cpp
Go to the documentation of this file.
1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/IR/Value.h"
23 #include "llvm/ADT/APFloat.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/STLExtras.h"
26 
27 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
28 
29 using namespace mlir;
30 using namespace mlir::func;
31 
32 //===----------------------------------------------------------------------===//
33 // FuncDialect
34 //===----------------------------------------------------------------------===//
35 
36 void FuncDialect::initialize() {
37  addOperations<
38 #define GET_OP_LIST
39 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
40  >();
41  declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
42  declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
43  declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
44  declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
45  FuncOp, ReturnOp>();
46 }
47 
48 /// Materialize a single constant operation from a given attribute value with
49 /// the desired resultant type.
51  Type type, Location loc) {
52  if (ConstantOp::isBuildableWith(value, type))
53  return ConstantOp::create(builder, loc, type,
54  llvm::cast<FlatSymbolRefAttr>(value));
55  return nullptr;
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // CallOp
60 //===----------------------------------------------------------------------===//
61 
62 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
63  // Check that the callee attribute was specified.
64  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
65  if (!fnAttr)
66  return emitOpError("requires a 'callee' symbol reference attribute");
67  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
68  if (!fn)
69  return emitOpError() << "'" << fnAttr.getValue()
70  << "' does not reference a valid function";
71 
72  // Verify that the operand and result types match the callee.
73  auto fnType = fn.getFunctionType();
74  if (fnType.getNumInputs() != getNumOperands())
75  return emitOpError("incorrect number of operands for callee");
76 
77  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
78  if (getOperand(i).getType() != fnType.getInput(i))
79  return emitOpError("operand type mismatch: expected operand type ")
80  << fnType.getInput(i) << ", but provided "
81  << getOperand(i).getType() << " for operand number " << i;
82 
83  if (fnType.getNumResults() != getNumResults())
84  return emitOpError("incorrect number of results for callee");
85 
86  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
87  if (getResult(i).getType() != fnType.getResult(i)) {
88  auto diag = emitOpError("result type mismatch at index ") << i;
89  diag.attachNote() << " op result types: " << getResultTypes();
90  diag.attachNote() << "function result types: " << fnType.getResults();
91  return diag;
92  }
93 
94  return success();
95 }
96 
97 FunctionType CallOp::getCalleeType() {
98  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // CallIndirectOp
103 //===----------------------------------------------------------------------===//
104 
105 /// Fold indirect calls that have a constant function as the callee operand.
106 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
107  PatternRewriter &rewriter) {
108  // Check that the callee is a constant callee.
109  SymbolRefAttr calledFn;
110  if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
111  return failure();
112 
113  // Replace with a direct call.
114  rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
115  indirectCall.getResultTypes(),
116  indirectCall.getArgOperands());
117  return success();
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // ConstantOp
122 //===----------------------------------------------------------------------===//
123 
124 LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
125  StringRef fnName = getValue();
126  Type type = getType();
127 
128  // Try to find the referenced function.
129  auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
130  this->getOperation(), StringAttr::get(getContext(), fnName));
131  if (!fn)
132  return emitOpError() << "reference to undefined function '" << fnName
133  << "'";
134 
135  // Check that the referenced function has the correct type.
136  if (fn.getFunctionType() != type)
137  return emitOpError("reference to function with mismatched type");
138 
139  return success();
140 }
141 
142 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
143  return getValueAttr();
144 }
145 
146 void ConstantOp::getAsmResultNames(
147  function_ref<void(Value, StringRef)> setNameFn) {
148  setNameFn(getResult(), "f");
149 }
150 
151 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
152  return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
153 }
154 
155 //===----------------------------------------------------------------------===//
156 // FuncOp
157 //===----------------------------------------------------------------------===//
158 
159 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
160  ArrayRef<NamedAttribute> attrs) {
161  OpBuilder builder(location->getContext());
162  OperationState state(location, getOperationName());
163  FuncOp::build(builder, state, name, type, attrs);
164  return cast<FuncOp>(Operation::create(state));
165 }
166 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
168  SmallVector<NamedAttribute, 8> attrRef(attrs);
169  return create(location, name, type, llvm::ArrayRef(attrRef));
170 }
171 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
173  ArrayRef<DictionaryAttr> argAttrs) {
174  FuncOp func = create(location, name, type, attrs);
175  func.setAllArgAttrs(argAttrs);
176  return func;
177 }
178 
179 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
180  FunctionType type, ArrayRef<NamedAttribute> attrs,
181  ArrayRef<DictionaryAttr> argAttrs) {
182  state.addAttribute(SymbolTable::getSymbolAttrName(),
183  builder.getStringAttr(name));
184  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
185  state.attributes.append(attrs.begin(), attrs.end());
186  state.addRegion();
187 
188  if (argAttrs.empty())
189  return;
190  assert(type.getNumInputs() == argAttrs.size());
192  builder, state, argAttrs, /*resultAttrs=*/{},
193  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
194 }
195 
196 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
197  auto buildFuncType =
198  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
200  std::string &) { return builder.getFunctionType(argTypes, results); };
201 
203  parser, result, /*allowVariadic=*/false,
204  getFunctionTypeAttrName(result.name), buildFuncType,
205  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
206 }
207 
208 void FuncOp::print(OpAsmPrinter &p) {
210  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
211  getArgAttrsAttrName(), getResAttrsAttrName());
212 }
213 
214 /// Clone the internal blocks from this function into dest and all attributes
215 /// from this function to dest.
216 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
217  // Add the attributes of this function to dest.
218  llvm::MapVector<StringAttr, Attribute> newAttrMap;
219  for (const auto &attr : dest->getAttrs())
220  newAttrMap.insert({attr.getName(), attr.getValue()});
221  for (const auto &attr : (*this)->getAttrs())
222  newAttrMap.insert({attr.getName(), attr.getValue()});
223 
224  auto newAttrs = llvm::to_vector(llvm::map_range(
225  newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
226  return NamedAttribute(attrPair.first, attrPair.second);
227  }));
228  dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
229 
230  // Clone the body.
231  getBody().cloneInto(&dest.getBody(), mapper);
232 }
233 
234 /// Create a deep copy of this function and all of its blocks, remapping
235 /// any operands that use values outside of the function using the map that is
236 /// provided (leaving them alone if no entry is present). Replaces references
237 /// to cloned sub-values with the corresponding value that is copied, and adds
238 /// those mappings to the mapper.
239 FuncOp FuncOp::clone(IRMapping &mapper) {
240  // Create the new function.
241  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
242 
243  // If the function has a body, then the user might be deleting arguments to
244  // the function by specifying them in the mapper. If so, we don't add the
245  // argument to the input type vector.
246  if (!isExternal()) {
247  FunctionType oldType = getFunctionType();
248 
249  unsigned oldNumArgs = oldType.getNumInputs();
250  SmallVector<Type, 4> newInputs;
251  newInputs.reserve(oldNumArgs);
252  for (unsigned i = 0; i != oldNumArgs; ++i)
253  if (!mapper.contains(getArgument(i)))
254  newInputs.push_back(oldType.getInput(i));
255 
256  /// If any of the arguments were dropped, update the type and drop any
257  /// necessary argument attributes.
258  if (newInputs.size() != oldNumArgs) {
259  newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
260  oldType.getResults()));
261 
262  if (ArrayAttr argAttrs = getAllArgAttrs()) {
263  SmallVector<Attribute> newArgAttrs;
264  newArgAttrs.reserve(newInputs.size());
265  for (unsigned i = 0; i != oldNumArgs; ++i)
266  if (!mapper.contains(getArgument(i)))
267  newArgAttrs.push_back(argAttrs[i]);
268  newFunc.setAllArgAttrs(newArgAttrs);
269  }
270  }
271  }
272 
273  /// Clone the current function into the new one and return it.
274  cloneInto(newFunc, mapper);
275  return newFunc;
276 }
277 FuncOp FuncOp::clone() {
278  IRMapping mapper;
279  return clone(mapper);
280 }
281 
282 //===----------------------------------------------------------------------===//
283 // ReturnOp
284 //===----------------------------------------------------------------------===//
285 
286 LogicalResult ReturnOp::verify() {
287  auto function = cast<FuncOp>((*this)->getParentOp());
288 
289  // The operand number and types must match the function signature.
290  const auto &results = function.getFunctionType().getResults();
291  if (getNumOperands() != results.size())
292  return emitOpError("has ")
293  << getNumOperands() << " operands, but enclosing function (@"
294  << function.getName() << ") returns " << results.size();
295 
296  for (unsigned i = 0, e = results.size(); i != e; ++i)
297  if (getOperand(i).getType() != results[i])
298  return emitError() << "type of return operand " << i << " ("
299  << getOperand(i).getType()
300  << ") doesn't match function result type ("
301  << results[i] << ")"
302  << " in function @" << function.getName();
303 
304  return success();
305 }
306 
307 //===----------------------------------------------------------------------===//
308 // TableGen'd op method definitions
309 //===----------------------------------------------------------------------===//
310 
311 #define GET_OP_CLASSES
312 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:205
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.