MLIR  19.0.0git
FuncOps.cpp
Go to the documentation of this file.
1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/IR/Value.h"
24 #include "llvm/ADT/APFloat.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include <numeric>
30 
31 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc"
32 
33 using namespace mlir;
34 using namespace mlir::func;
35 
36 //===----------------------------------------------------------------------===//
37 // FuncDialect
38 //===----------------------------------------------------------------------===//
39 
/// Dialect initialization hook.
///
/// Registers every operation emitted into FuncOps.cpp.inc under the
/// GET_OP_LIST section with this dialect. It then declares promised
/// interfaces:
///   - DialectInlinerInterface and ConvertToLLVMPatternInterface for the
///     dialect itself;
///   - bufferization::BufferizableOpInterface for CallOp, FuncOp and ReturnOp.
/// NOTE(review): "promised" presumably means the implementations are attached
/// elsewhere (e.g. via external models registered by another library);
/// confirm against the MLIR Dialect documentation.
void FuncDialect::initialize() {
  // The generated op list is expanded inline as the template argument pack.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
      >();
  declarePromisedInterface<DialectInlinerInterface, FuncDialect>();
  declarePromisedInterface<ConvertToLLVMPatternInterface, FuncDialect>();
  // One interface promised for several concrete op types at once.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, CallOp,
                            FuncOp, ReturnOp>();
}
50 
51 /// Materialize a single constant operation from a given attribute value with
52 /// the desired resultant type.
54  Type type, Location loc) {
55  if (ConstantOp::isBuildableWith(value, type))
56  return builder.create<ConstantOp>(loc, type,
57  llvm::cast<FlatSymbolRefAttr>(value));
58  return nullptr;
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // CallOp
63 //===----------------------------------------------------------------------===//
64 
65 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
66  // Check that the callee attribute was specified.
67  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
68  if (!fnAttr)
69  return emitOpError("requires a 'callee' symbol reference attribute");
70  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
71  if (!fn)
72  return emitOpError() << "'" << fnAttr.getValue()
73  << "' does not reference a valid function";
74 
75  // Verify that the operand and result types match the callee.
76  auto fnType = fn.getFunctionType();
77  if (fnType.getNumInputs() != getNumOperands())
78  return emitOpError("incorrect number of operands for callee");
79 
80  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
81  if (getOperand(i).getType() != fnType.getInput(i))
82  return emitOpError("operand type mismatch: expected operand type ")
83  << fnType.getInput(i) << ", but provided "
84  << getOperand(i).getType() << " for operand number " << i;
85 
86  if (fnType.getNumResults() != getNumResults())
87  return emitOpError("incorrect number of results for callee");
88 
89  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
90  if (getResult(i).getType() != fnType.getResult(i)) {
91  auto diag = emitOpError("result type mismatch at index ") << i;
92  diag.attachNote() << " op result types: " << getResultTypes();
93  diag.attachNote() << "function result types: " << fnType.getResults();
94  return diag;
95  }
96 
97  return success();
98 }
99 
100 FunctionType CallOp::getCalleeType() {
101  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // CallIndirectOp
106 //===----------------------------------------------------------------------===//
107 
108 /// Fold indirect calls that have a constant function as the callee operand.
109 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
110  PatternRewriter &rewriter) {
111  // Check that the callee is a constant callee.
112  SymbolRefAttr calledFn;
113  if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
114  return failure();
115 
116  // Replace with a direct call.
117  rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn,
118  indirectCall.getResultTypes(),
119  indirectCall.getArgOperands());
120  return success();
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // ConstantOp
125 //===----------------------------------------------------------------------===//
126 
128  StringRef fnName = getValue();
129  Type type = getType();
130 
131  // Try to find the referenced function.
132  auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName);
133  if (!fn)
134  return emitOpError() << "reference to undefined function '" << fnName
135  << "'";
136 
137  // Check that the referenced function has the correct type.
138  if (fn.getFunctionType() != type)
139  return emitOpError("reference to function with mismatched type");
140 
141  return success();
142 }
143 
144 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
145  return getValueAttr();
146 }
147 
148 void ConstantOp::getAsmResultNames(
149  function_ref<void(Value, StringRef)> setNameFn) {
150  setNameFn(getResult(), "f");
151 }
152 
153 bool ConstantOp::isBuildableWith(Attribute value, Type type) {
154  return llvm::isa<FlatSymbolRefAttr>(value) && llvm::isa<FunctionType>(type);
155 }
156 
157 //===----------------------------------------------------------------------===//
158 // FuncOp
159 //===----------------------------------------------------------------------===//
160 
161 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
162  ArrayRef<NamedAttribute> attrs) {
163  OpBuilder builder(location->getContext());
164  OperationState state(location, getOperationName());
165  FuncOp::build(builder, state, name, type, attrs);
166  return cast<FuncOp>(Operation::create(state));
167 }
168 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
170  SmallVector<NamedAttribute, 8> attrRef(attrs);
171  return create(location, name, type, llvm::ArrayRef(attrRef));
172 }
173 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
175  ArrayRef<DictionaryAttr> argAttrs) {
176  FuncOp func = create(location, name, type, attrs);
177  func.setAllArgAttrs(argAttrs);
178  return func;
179 }
180 
181 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
182  FunctionType type, ArrayRef<NamedAttribute> attrs,
183  ArrayRef<DictionaryAttr> argAttrs) {
184  state.addAttribute(SymbolTable::getSymbolAttrName(),
185  builder.getStringAttr(name));
186  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
187  state.attributes.append(attrs.begin(), attrs.end());
188  state.addRegion();
189 
190  if (argAttrs.empty())
191  return;
192  assert(type.getNumInputs() == argAttrs.size());
194  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
195  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
196 }
197 
199  auto buildFuncType =
200  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
202  std::string &) { return builder.getFunctionType(argTypes, results); };
203 
205  parser, result, /*allowVariadic=*/false,
206  getFunctionTypeAttrName(result.name), buildFuncType,
207  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
208 }
209 
210 void FuncOp::print(OpAsmPrinter &p) {
212  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
213  getArgAttrsAttrName(), getResAttrsAttrName());
214 }
215 
216 /// Clone the internal blocks from this function into dest and all attributes
217 /// from this function to dest.
218 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
219  // Add the attributes of this function to dest.
220  llvm::MapVector<StringAttr, Attribute> newAttrMap;
221  for (const auto &attr : dest->getAttrs())
222  newAttrMap.insert({attr.getName(), attr.getValue()});
223  for (const auto &attr : (*this)->getAttrs())
224  newAttrMap.insert({attr.getName(), attr.getValue()});
225 
226  auto newAttrs = llvm::to_vector(llvm::map_range(
227  newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
228  return NamedAttribute(attrPair.first, attrPair.second);
229  }));
230  dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
231 
232  // Clone the body.
233  getBody().cloneInto(&dest.getBody(), mapper);
234 }
235 
236 /// Create a deep copy of this function and all of its blocks, remapping
237 /// any operands that use values outside of the function using the map that is
238 /// provided (leaving them alone if no entry is present). Replaces references
239 /// to cloned sub-values with the corresponding value that is copied, and adds
240 /// those mappings to the mapper.
241 FuncOp FuncOp::clone(IRMapping &mapper) {
242  // Create the new function.
243  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
244 
245  // If the function has a body, then the user might be deleting arguments to
246  // the function by specifying them in the mapper. If so, we don't add the
247  // argument to the input type vector.
248  if (!isExternal()) {
249  FunctionType oldType = getFunctionType();
250 
251  unsigned oldNumArgs = oldType.getNumInputs();
252  SmallVector<Type, 4> newInputs;
253  newInputs.reserve(oldNumArgs);
254  for (unsigned i = 0; i != oldNumArgs; ++i)
255  if (!mapper.contains(getArgument(i)))
256  newInputs.push_back(oldType.getInput(i));
257 
258  /// If any of the arguments were dropped, update the type and drop any
259  /// necessary argument attributes.
260  if (newInputs.size() != oldNumArgs) {
261  newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
262  oldType.getResults()));
263 
264  if (ArrayAttr argAttrs = getAllArgAttrs()) {
265  SmallVector<Attribute> newArgAttrs;
266  newArgAttrs.reserve(newInputs.size());
267  for (unsigned i = 0; i != oldNumArgs; ++i)
268  if (!mapper.contains(getArgument(i)))
269  newArgAttrs.push_back(argAttrs[i]);
270  newFunc.setAllArgAttrs(newArgAttrs);
271  }
272  }
273  }
274 
275  /// Clone the current function into the new one and return it.
276  cloneInto(newFunc, mapper);
277  return newFunc;
278 }
279 FuncOp FuncOp::clone() {
280  IRMapping mapper;
281  return clone(mapper);
282 }
283 
284 //===----------------------------------------------------------------------===//
285 // ReturnOp
286 //===----------------------------------------------------------------------===//
287 
289  auto function = cast<FuncOp>((*this)->getParentOp());
290 
291  // The operand number and types must match the function signature.
292  const auto &results = function.getFunctionType().getResults();
293  if (getNumOperands() != results.size())
294  return emitOpError("has ")
295  << getNumOperands() << " operands, but enclosing function (@"
296  << function.getName() << ") returns " << results.size();
297 
298  for (unsigned i = 0, e = results.size(); i != e; ++i)
299  if (getOperand(i).getType() != results[i])
300  return emitError() << "type of return operand " << i << " ("
301  << getOperand(i).getType()
302  << ") doesn't match function result type ("
303  << results[i] << ")"
304  << " in function @" << function.getName();
305 
306  return success();
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // TableGen'd op method definitions
311 //===----------------------------------------------------------------------===//
312 
313 #define GET_OP_CLASSES
314 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
Definition: IRMapping.h:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:67
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.