MLIR  21.0.0git
FuncOps.cpp
Go to the documentation of this file.
1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/IR/Value.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <numeric>
30 
31 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
32 
33 using namespace mlir;
34 using namespace mlir::func;
35 
36 //===----------------------------------------------------------------------===//
37 // FuncDialect
38 //===----------------------------------------------------------------------===//
39 
40 void FuncDialect::initialize() {
41  addOperations<
42 #define GET_OP_LIST
43 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
44  >();
45  declarePromisedInterface<ConvertToEmitCPatternInterface, FuncDialect>();
46  declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
47  declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
48  declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
49  FuncOp, ReturnOp>();
50 }
51 
52 /// Materialize a single constant operation from a given attribute value with
53 /// the desired resultant type.
// NOTE(review): the first line of this definition's signature (original
// line 54) is missing from this listing; only the trailing parameters are
// visible below. Restore it from upstream before relying on this text.
55  Type type, Location loc) {
  // Only a FlatSymbolRefAttr paired with a FunctionType is accepted (this is
  // exactly what ConstantOp::isBuildableWith checks); in that case a
  // ConstantOp is created at `loc`. Any other attribute/type combination
  // yields nullptr, i.e. no constant is materialized.
56  if (ConstantOp::isBuildableWith(value, type))
57  return builder.create<ConstantOp>(loc, type,
58  llvm::cast<FlatSymbolRefAttr>(value));
59  return nullptr;
60 }
61 
62 //===----------------------------------------------------------------------===//
63 // CallOp
64 //===----------------------------------------------------------------------===//
65 
66 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
67  // Check that the callee attribute was specified.
68  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
69  if (!fnAttr)
70  return emitOpError("requires a 'callee' symbol reference attribute");
71  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
72  if (!fn)
73  return emitOpError() << "'" << fnAttr.getValue()
74  << "' does not reference a valid function";
75 
76  // Verify that the operand and result types match the callee.
77  auto fnType = fn.getFunctionType();
78  if (fnType.getNumInputs() != getNumOperands())
79  return emitOpError("incorrect number of operands for callee");
80 
81  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
82  if (getOperand(i).getType() != fnType.getInput(i))
83  return emitOpError("operand type mismatch: expected operand type ")
84  << fnType.getInput(i) << ", but provided "
85  << getOperand(i).getType() << " for operand number " << i;
86 
87  if (fnType.getNumResults() != getNumResults())
88  return emitOpError("incorrect number of results for callee");
89 
90  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
91  if (getResult(i).getType() != fnType.getResult(i)) {
92  auto diag = emitOpError("result type mismatch at index ") << i;
93  diag.attachNote() << " op result types: " << getResultTypes();
94  diag.attachNote() << "function result types: " << fnType.getResults();
95  return diag;
96  }
97 
98  return success();
99 }
100 
101 FunctionType CallOp::getCalleeType() {
102  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // CallIndirectOp
107 //===----------------------------------------------------------------------===//
108 
109 /// Fold indirect calls that have a constant function as the callee operand.
110 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
111  PatternRewriter &rewriter) {
112  // Check that the callee is a constant callee.
113  SymbolRefAttr calledFn;
114  if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
115  return failure();
116 
117  // Replace with a direct call.
118  rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
119  indirectCall.getResultTypes(),
120  indirectCall.getArgOperands());
121  return success();
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // ConstantOp
126 //===----------------------------------------------------------------------===//
127 
128 LogicalResult ConstantOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
129  StringRef fnName = getValue();
130  Type type = getType();
131 
132  // Try to find the referenced function.
133  auto fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(
134  this->getOperation(), StringAttr::get(getContext(), fnName));
135  if (!fn)
136  return emitOpError() << "reference to undefined function '" << fnName
137  << "'";
138 
139  // Check that the referenced function has the correct type.
140  if (fn.getFunctionType() != type)
141  return emitOpError("reference to function with mismatched type");
142 
143  return success();
144 }
145 
146 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
147  return getValueAttr();
148 }
149 
150 void ConstantOp::getAsmResultNames(
151  function_ref<void(Value, StringRef)> setNameFn) {
152  setNameFn(getResult(), "f");
153 }
154 
155 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
156  return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // FuncOp
161 //===----------------------------------------------------------------------===//
162 
163 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
164  ArrayRef<NamedAttribute> attrs) {
165  OpBuilder builder(location->getContext());
166  OperationState state(location, getOperationName());
167  FuncOp::build(builder, state, name, type, attrs);
168  return cast<FuncOp>(Operation::create(state));
169 }
// Overload of FuncOp::create that copies `attrs` into a local
// SmallVector<NamedAttribute, 8> and forwards it, as an
// ArrayRef<NamedAttribute>, to the corresponding create overload.
// NOTE(review): the line declaring `attrs` (original line 171) is missing
// from this listing; presumably it is an Operation::dialect_attr_range —
// confirm against upstream.
170 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
172  SmallVector<NamedAttribute, 8> attrRef(attrs);
173  return create(location, name, type, llvm::ArrayRef(attrRef));
174 }
// Creates the function through the create overload that takes `attrs`, then
// assigns the per-argument attribute dictionaries from `argAttrs` with
// setAllArgAttrs.
// NOTE(review): the line declaring `attrs` (original line 176) is missing
// from this listing; presumably `ArrayRef<NamedAttribute> attrs,` — confirm
// against upstream.
175 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
177  ArrayRef<DictionaryAttr> argAttrs) {
178  FuncOp func = create(location, name, type, attrs);
179  func.setAllArgAttrs(argAttrs);
180  return func;
181 }
182 
// Populates `state` for a FuncOp:
// - records `name` as the symbol-name attribute and `type` as the function
//   type attribute;
// - appends the caller-supplied `attrs`;
// - adds one region for the body.
// When `argAttrs` is non-empty it must hold one dictionary per function input
// (asserted) and is attached as argument attributes; no result attributes
// are passed.
183 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
184  FunctionType type, ArrayRef<NamedAttribute> attrs,
185  ArrayRef<DictionaryAttr> argAttrs) {
186  state.addAttribute(SymbolTable::getSymbolAttrName(),
187  builder.getStringAttr(name));
188  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
189  state.attributes.append(attrs.begin(), attrs.end());
190  state.addRegion();
191 
  // Argument attributes are optional; nothing more to do when none are given.
192  if (argAttrs.empty())
193  return;
194  assert(type.getNumInputs() == argAttrs.size());
  // NOTE(review): original line 195 (the callee of the call continued below)
  // is missing from this listing; presumably
  // `function_interface_impl::addArgAndResultAttrs(` — confirm upstream.
196  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
197  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
198 }
199 
// Parses the custom assembly form of a FuncOp. Variadic argument lists are
// rejected (allowVariadic=false). The function type, argument-attribute and
// result-attribute attribute names are taken from `result.name`.
200 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Builds the FunctionType from the parsed argument and result types; the
  // trailing std::string& parameter is unused.
201  auto buildFuncType =
202  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
204  std::string &) { return builder.getFunctionType(argTypes, results); };
205 
  // NOTE(review): original lines 203 (a lambda parameter, presumably
  // `function_interface_impl::VariadicFlag,`) and 206 (presumably
  // `return function_interface_impl::parseFunctionOp(`) are missing from
  // this listing — confirm against upstream.
207  parser, result, /*allowVariadic=*/false,
208  getFunctionTypeAttrName(result.name), buildFuncType,
209  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
210 }
211 
// Prints the custom assembly form of this FuncOp, passing isVariadic=false
// and this op's function-type, argument-attribute and result-attribute
// attribute names to the printer routine.
// NOTE(review): original line 213 (the callee, presumably
// `function_interface_impl::printFunctionOp(`) is missing from this
// listing — confirm against upstream.
212 void FuncOp::print(OpAsmPrinter &p) {
214  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
215  getArgAttrsAttrName(), getResAttrsAttrName());
216 }
217 
218 /// Clone the internal blocks from this function into dest and all attributes
219 /// from this function to dest.
220 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
221  // Add the attributes of this function to dest.
222  llvm::MapVector<StringAttr, Attribute> newAttrMap;
223  for (const auto &attr : dest->getAttrs())
224  newAttrMap.insert({attr.getName(), attr.getValue()});
225  for (const auto &attr : (*this)->getAttrs())
226  newAttrMap.insert({attr.getName(), attr.getValue()});
227 
228  auto newAttrs = llvm::to_vector(llvm::map_range(
229  newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
230  return NamedAttribute(attrPair.first, attrPair.second);
231  }));
232  dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
233 
234  // Clone the body.
235  getBody().cloneInto(&dest.getBody(), mapper);
236 }
237 
238 /// Create a deep copy of this function and all of its blocks, remapping
239 /// any operands that use values outside of the function using the map that is
240 /// provided (leaving them alone if no entry is present). Replaces references
241 /// to cloned sub-values with the corresponding value that is copied, and adds
242 /// those mappings to the mapper.
243 FuncOp FuncOp::clone(IRMapping &mapper) {
244  // Create the new function.
245  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
246 
247  // If the function has a body, then the user might be deleting arguments to
248  // the function by specifying them in the mapper. If so, we don't add the
249  // argument to the input type vector.
250  if (!isExternal()) {
251  FunctionType oldType = getFunctionType();
252 
253  unsigned oldNumArgs = oldType.getNumInputs();
254  SmallVector<Type, 4> newInputs;
255  newInputs.reserve(oldNumArgs);
256  for (unsigned i = 0; i != oldNumArgs; ++i)
257  if (!mapper.contains(getArgument(i)))
258  newInputs.push_back(oldType.getInput(i));
259 
260  /// If any of the arguments were dropped, update the type and drop any
261  /// necessary argument attributes.
262  if (newInputs.size() != oldNumArgs) {
263  newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
264  oldType.getResults()));
265 
266  if (ArrayAttr argAttrs = getAllArgAttrs()) {
267  SmallVector<Attribute> newArgAttrs;
268  newArgAttrs.reserve(newInputs.size());
269  for (unsigned i = 0; i != oldNumArgs; ++i)
270  if (!mapper.contains(getArgument(i)))
271  newArgAttrs.push_back(argAttrs[i]);
272  newFunc.setAllArgAttrs(newArgAttrs);
273  }
274  }
275  }
276 
277  /// Clone the current function into the new one and return it.
278  cloneInto(newFunc, mapper);
279  return newFunc;
280 }
281 FuncOp FuncOp::clone() {
282  IRMapping mapper;
283  return clone(mapper);
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // ReturnOp
288 //===----------------------------------------------------------------------===//
289 
290 LogicalResult ReturnOp::verify() {
291  auto function = cast<FuncOp>((*this)->getParentOp());
292 
293  // The operand number and types must match the function signature.
294  const auto &results = function.getFunctionType().getResults();
295  if (getNumOperands() != results.size())
296  return emitOpError("has ")
297  << getNumOperands() << " operands, but enclosing function (@"
298  << function.getName() << ") returns " << results.size();
299 
300  for (unsigned i = 0, e = results.size(); i != e; ++i)
301  if (getOperand(i).getType() != results[i])
302  return emitError() << "type of return operand " << i << " ("
303  << getOperand(i).getType()
304  << ") doesn't match function result type ("
305  << results[i] << ")"
306  << " in function @" << function.getName();
307 
308  return success();
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // TableGen'd op method definitions
313 //===----------------------------------------------------------------------===//
314 
315 #define GET_OP_CLASSES
316 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:204
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.