MLIR  14.0.0git
FunctionImplementation.cpp
Go to the documentation of this file.
1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/SymbolTable.h"
13 
14 using namespace mlir;
15 
17  OpAsmParser &parser, bool allowAttributes, bool allowVariadic,
20  SmallVectorImpl<Location> &argLocations, bool &isVariadic) {
21  if (parser.parseLParen())
22  return failure();
23 
24  // The argument list either has to consistently have ssa-id's followed by
25  // types, or just be a type list. It isn't ok to sometimes have SSA ID's and
26  // sometimes not.
27  auto parseArgument = [&]() -> ParseResult {
28  llvm::SMLoc loc = parser.getCurrentLocation();
29 
30  // Parse argument name if present.
31  OpAsmParser::OperandType argument;
32  Type argumentType;
33  if (succeeded(parser.parseOptionalRegionArgument(argument)) &&
34  !argument.name.empty()) {
35  // Reject this if the preceding argument was missing a name.
36  if (argNames.empty() && !argTypes.empty())
37  return parser.emitError(loc, "expected type instead of SSA identifier");
38  argNames.push_back(argument);
39 
40  if (parser.parseColonType(argumentType))
41  return failure();
42  } else if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
43  isVariadic = true;
44  return success();
45  } else if (!argNames.empty()) {
46  // Reject this if the preceding argument had a name.
47  return parser.emitError(loc, "expected SSA identifier");
48  } else if (parser.parseType(argumentType)) {
49  return failure();
50  }
51 
52  // Add the argument type.
53  argTypes.push_back(argumentType);
54 
55  // Parse any argument attributes.
56  NamedAttrList attrs;
57  if (parser.parseOptionalAttrDict(attrs))
58  return failure();
59  if (!allowAttributes && !attrs.empty())
60  return parser.emitError(loc, "expected arguments without attributes");
61  argAttrs.push_back(attrs);
62 
63  // Parse a location if specified.
64  Optional<Location> explicitLoc;
65  if (!argument.name.empty() &&
66  parser.parseOptionalLocationSpecifier(explicitLoc))
67  return failure();
68  if (!explicitLoc)
69  explicitLoc = parser.getEncodedSourceLoc(loc);
70  argLocations.push_back(*explicitLoc);
71 
72  return success();
73  };
74 
75  // Parse the function arguments.
76  isVariadic = false;
77  if (failed(parser.parseOptionalRParen())) {
78  do {
79  unsigned numTypedArguments = argTypes.size();
80  if (parseArgument())
81  return failure();
82 
83  llvm::SMLoc loc = parser.getCurrentLocation();
84  if (argTypes.size() == numTypedArguments &&
85  succeeded(parser.parseOptionalComma()))
86  return parser.emitError(
87  loc, "variadic arguments must be in the end of the argument list");
88  } while (succeeded(parser.parseOptionalComma()));
89  parser.parseRParen();
90  }
91 
92  return success();
93 }
94 
95 /// Parse a function result list.
96 ///
97 /// function-result-list ::= function-result-list-parens
98 /// | non-function-type
99 /// function-result-list-parens ::= `(` `)`
100 /// | `(` function-result-list-no-parens `)`
101 /// function-result-list-no-parens ::= function-result (`,` function-result)*
102 /// function-result ::= type attribute-dict?
103 ///
104 static ParseResult
106  SmallVectorImpl<NamedAttrList> &resultAttrs) {
107  if (failed(parser.parseOptionalLParen())) {
108  // We already know that there is no `(`, so parse a type.
109  // Because there is no `(`, it cannot be a function type.
110  Type ty;
111  if (parser.parseType(ty))
112  return failure();
113  resultTypes.push_back(ty);
114  resultAttrs.emplace_back();
115  return success();
116  }
117 
118  // Special case for an empty set of parens.
119  if (succeeded(parser.parseOptionalRParen()))
120  return success();
121 
122  // Parse individual function results.
123  do {
124  resultTypes.emplace_back();
125  resultAttrs.emplace_back();
126  if (parser.parseType(resultTypes.back()) ||
127  parser.parseOptionalAttrDict(resultAttrs.back())) {
128  return failure();
129  }
130  } while (succeeded(parser.parseOptionalComma()));
131  return parser.parseRParen();
132 }
133 
135  OpAsmParser &parser, bool allowVariadic,
138  SmallVectorImpl<Location> &argLocations, bool &isVariadic,
139  SmallVectorImpl<Type> &resultTypes,
140  SmallVectorImpl<NamedAttrList> &resultAttrs) {
141  bool allowArgAttrs = true;
142  if (parseFunctionArgumentList(parser, allowArgAttrs, allowVariadic, argNames,
143  argTypes, argAttrs, argLocations, isVariadic))
144  return failure();
145  if (succeeded(parser.parseOptionalArrow()))
146  return parseFunctionResultList(parser, resultTypes, resultAttrs);
147  return success();
148 }
149 
150 /// Implementation of `addArgAndResultAttrs` that is attribute list type
151 /// agnostic.
152 template <typename AttrListT, typename AttrArrayBuildFnT>
153 static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result,
154  ArrayRef<AttrListT> argAttrs,
155  ArrayRef<AttrListT> resultAttrs,
156  AttrArrayBuildFnT &&buildAttrArrayFn) {
157  auto nonEmptyAttrsFn = [](const AttrListT &attrs) { return !attrs.empty(); };
158 
159  // Add the attributes to the function arguments.
160  if (!argAttrs.empty() && llvm::any_of(argAttrs, nonEmptyAttrsFn)) {
161  ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(argAttrs));
163  attrDicts);
164  }
165  // Add the attributes to the function results.
166  if (!resultAttrs.empty() && llvm::any_of(resultAttrs, nonEmptyAttrsFn)) {
167  ArrayAttr attrDicts = builder.getArrayAttr(buildAttrArrayFn(resultAttrs));
169  attrDicts);
170  }
171 }
172 
174  Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
175  ArrayRef<DictionaryAttr> resultAttrs) {
176  auto buildFn = [](ArrayRef<DictionaryAttr> attrs) {
177  return ArrayRef<Attribute>(attrs.data(), attrs.size());
178  };
179  addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
180 }
182  Builder &builder, OperationState &result, ArrayRef<NamedAttrList> argAttrs,
183  ArrayRef<NamedAttrList> resultAttrs) {
184  MLIRContext *context = builder.getContext();
185  auto buildFn = [=](ArrayRef<NamedAttrList> attrs) {
186  return llvm::to_vector<8>(
187  llvm::map_range(attrs, [=](const NamedAttrList &attrList) -> Attribute {
188  return attrList.getDictionary(context);
189  }));
190  };
191  addArgAndResultAttrsImpl(builder, result, argAttrs, resultAttrs, buildFn);
192 }
193 
195  OpAsmParser &parser, OperationState &result, bool allowVariadic,
196  FuncTypeBuilder funcTypeBuilder) {
199  SmallVector<NamedAttrList> resultAttrs;
200  SmallVector<Type> argTypes;
201  SmallVector<Type> resultTypes;
202  SmallVector<Location> argLocations;
203  auto &builder = parser.getBuilder();
204 
205  // Parse visibility.
207 
208  // Parse the name as a symbol.
209  StringAttr nameAttr;
210  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
211  result.attributes))
212  return failure();
213 
214  // Parse the function signature.
215  llvm::SMLoc signatureLocation = parser.getCurrentLocation();
216  bool isVariadic = false;
217  if (parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
218  argAttrs, argLocations, isVariadic, resultTypes,
219  resultAttrs))
220  return failure();
221 
222  std::string errorMessage;
223  Type type = funcTypeBuilder(builder, argTypes, resultTypes,
224  VariadicFlag(isVariadic), errorMessage);
225  if (!type) {
226  return parser.emitError(signatureLocation)
227  << "failed to construct function type"
228  << (errorMessage.empty() ? "" : ": ") << errorMessage;
229  }
230  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
231 
232  // If function attributes are present, parse them.
233  NamedAttrList parsedAttributes;
234  llvm::SMLoc attributeDictLocation = parser.getCurrentLocation();
235  if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
236  return failure();
237 
238  // Disallow attributes that are inferred from elsewhere in the attribute
239  // dictionary.
240  for (StringRef disallowed :
242  getTypeAttrName()}) {
243  if (parsedAttributes.get(disallowed))
244  return parser.emitError(attributeDictLocation, "'")
245  << disallowed
246  << "' is an inferred attribute and should not be specified in the "
247  "explicit attribute dictionary";
248  }
249  result.attributes.append(parsedAttributes);
250 
251  // Add the attributes to the function arguments.
252  assert(argAttrs.size() == argTypes.size());
253  assert(resultAttrs.size() == resultTypes.size());
254  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
255 
256  // Parse the optional function body. The printer will not print the body if
257  // its empty, so disallow parsing of empty body in the parser.
258  auto *body = result.addRegion();
259  llvm::SMLoc loc = parser.getCurrentLocation();
260  OptionalParseResult parseResult = parser.parseOptionalRegion(
261  *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes,
262  entryArgs.empty() ? ArrayRef<Location>() : argLocations,
263  /*enableNameShadowing=*/false);
264  if (parseResult.hasValue()) {
265  if (failed(*parseResult))
266  return failure();
267  // Function body was parsed, make sure its not empty.
268  if (body->empty())
269  return parser.emitError(loc, "expected non-empty function body");
270  }
271  return success();
272 }
273 
274 /// Print a function result list. The provided `attrs` must either be null, or
275 /// contain a set of DictionaryAttrs of the same arity as `types`.
277  ArrayAttr attrs) {
278  assert(!types.empty() && "Should not be called for empty result list.");
279  assert((!attrs || attrs.size() == types.size()) &&
280  "Invalid number of attributes.");
281 
282  auto &os = p.getStream();
283  bool needsParens = types.size() > 1 || types[0].isa<FunctionType>() ||
284  (attrs && !attrs[0].cast<DictionaryAttr>().empty());
285  if (needsParens)
286  os << '(';
287  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
288  p.printType(types[i]);
289  if (attrs)
290  p.printOptionalAttrDict(attrs[i].cast<DictionaryAttr>().getValue());
291  });
292  if (needsParens)
293  os << ')';
294 }
295 
297  OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
298  ArrayRef<Type> resultTypes) {
299  Region &body = op->getRegion(0);
300  bool isExternal = body.empty();
301 
302  p << '(';
303  ArrayAttr argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
304  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
305  if (i > 0)
306  p << ", ";
307 
308  if (!isExternal) {
310  if (argAttrs)
311  attrs = argAttrs[i].cast<DictionaryAttr>().getValue();
312  p.printRegionArgument(body.getArgument(i), attrs);
313  } else {
314  p.printType(argTypes[i]);
315  if (argAttrs)
316  p.printOptionalAttrDict(argAttrs[i].cast<DictionaryAttr>().getValue());
317  }
318  }
319 
320  if (isVariadic) {
321  if (!argTypes.empty())
322  p << ", ";
323  p << "...";
324  }
325 
326  p << ')';
327 
328  if (!resultTypes.empty()) {
329  p.getStream() << " -> ";
330  auto resultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
331  printFunctionResultList(p, resultTypes, resultAttrs);
332  }
333 }
334 
336  OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults,
337  ArrayRef<StringRef> elided) {
338  // Print out function attributes, if present.
339  SmallVector<StringRef, 2> ignoredAttrs = {
342  ignoredAttrs.append(elided.begin(), elided.end());
343 
344  p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
345 }
346 
348  OpAsmPrinter &p, Operation *op, ArrayRef<Type> argTypes, bool isVariadic,
349  ArrayRef<Type> resultTypes) {
350  // Print the operation and the function name.
351  auto funcName =
353  .getValue();
354  p << ' ';
355 
356  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
357  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
358  p << visibility.getValue() << ' ';
359  p.printSymbolName(funcName);
360 
361  printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
362  printFunctionAttributes(p, op, argTypes.size(), resultTypes.size(),
363  {visibilityAttrName});
364  // Print the body if this is not an external function.
365  Region &body = op->getRegion(0);
366  if (!body.empty()) {
367  p << ' ';
368  p.printRegion(body, /*printEntryBlockArgs=*/false,
369  /*printBlockTerminators=*/true);
370  }
371 }
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
This is the representation of an operand reference.
void printFunctionSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Include the generated interface declarations.
StringRef getResultDictAttrName()
Return the name of the attribute used for function result attributes.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:55
virtual ParseResult parseLParen()=0
Parse a ( token.
MLIRContext * getContext() const
Definition: Builders.h:54
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
StringRef getArgDictAttrName()
Return the name of the attribute used for function argument attributes.
virtual void printType(Type type)
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:308
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:327
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
static void addArgAndResultAttrsImpl(Builder &builder, OperationState &result, ArrayRef< AttrListT > argAttrs, ArrayRef< AttrListT > resultAttrs, AttrArrayBuildFnT &&buildAttrArrayFn)
Implementation of addArgAndResultAttrs that is attribute list type agnostic.
static StringRef getVisibilityAttrName()
Return the name of the attribute used for symbol visibility.
Definition: SymbolTable.h:61
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
virtual ParseResult parseOptionalRegionArgument(OperandType &argument)=0
Parse a region argument if present.
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
A named class for passing around the variadic flag.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
static void printFunctionResultList(OpAsmPrinter &p, ArrayRef< Type > types, ArrayAttr attrs)
Print a function result list.
ParseResult parseSymbolName(StringAttr &result, StringRef attrName, NamedAttrList &attrs)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute named 'attrName'...
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::OperandType > &argNames, SmallVectorImpl< Type > &argTypes, SmallVectorImpl< NamedAttrList > &argAttrs, SmallVectorImpl< Location > &argLocations, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< NamedAttrList > &resultAttrs)
Parses a function signature using parser.
ParseResult parseFunctionArgumentList(OpAsmParser &parser, bool allowAttributes, bool allowVariadic, SmallVectorImpl< OpAsmParser::OperandType > &argNames, SmallVectorImpl< Type > &argTypes, SmallVectorImpl< NamedAttrList > &argAttrs, SmallVectorImpl< Location > &argLocations, bool &isVariadic)
Parses function arguments using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs, unsigned numResults, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
bool empty()
Definition: Region.h:60
virtual llvm::SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region if present.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
virtual ParseResult parseRParen()=0
Parse a ) token.
void printFunctionOp(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Printer implementation for function-like operations.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static ParseResult parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< NamedAttrList > &resultAttrs)
Parse a function result list.
This represents an operation in an abstracted form, suitable for use with the builder APIs...
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
virtual ParseResult parseOptionalLocationSpecifier(Optional< Location > &result)=0
Parse a loc(...) specifier if present, filling in result if so.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:52
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
NamedAttrList attributes
virtual InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
Region * addRegion()
Create a region that should be attached to the operation.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
virtual ParseResult parseType(Type &result)=0
Parse a type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, FuncTypeBuilder funcTypeBuilder)
Parser implementation for function-like operations.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
bool hasValue() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:62
virtual ParseResult parseOptionalEllipsis()=0
Parse a `...` token if present.
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:429
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.