MLIR  19.0.0git
FunctionImplementation.cpp
Go to the documentation of this file.
1 //===- FunctionImplementation.cpp - Utilities for function-like ops -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/SymbolTable.h"
13 
14 using namespace mlir;
15 
16 static ParseResult
17 parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic,
19  bool &isVariadic) {
20 
21  // Parse the function arguments. The argument list either has to consistently
22  // have ssa-id's followed by types, or just be a type list. It isn't ok to
23  // sometimes have SSA ID's and sometimes not.
24  isVariadic = false;
25 
26  return parser.parseCommaSeparatedList(
28  // Ellipsis must be at end of the list.
29  if (isVariadic)
30  return parser.emitError(
31  parser.getCurrentLocation(),
32  "variadic arguments must be in the end of the argument list");
33 
34  // Handle ellipsis as a special case.
35  if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) {
36  // This is a variadic designator.
37  isVariadic = true;
38  return success(); // Stop parsing arguments.
39  }
40  // Parse argument name if present.
41  OpAsmParser::Argument argument;
42  auto argPresent = parser.parseOptionalArgument(
43  argument, /*allowType=*/true, /*allowAttrs=*/true);
44  if (argPresent.has_value()) {
45  if (failed(argPresent.value()))
46  return failure(); // Present but malformed.
47 
48  // Reject this if the preceding argument was missing a name.
49  if (!arguments.empty() && arguments.back().ssaName.name.empty())
50  return parser.emitError(argument.ssaName.location,
51  "expected type instead of SSA identifier");
52 
53  } else {
54  argument.ssaName.location = parser.getCurrentLocation();
55  // Otherwise we just have a type list without SSA names. Reject
56  // this if the preceding argument had a name.
57  if (!arguments.empty() && !arguments.back().ssaName.name.empty())
58  return parser.emitError(argument.ssaName.location,
59  "expected SSA identifier");
60 
61  NamedAttrList attrs;
62  if (parser.parseType(argument.type) ||
63  parser.parseOptionalAttrDict(attrs) ||
64  parser.parseOptionalLocationSpecifier(argument.sourceLoc))
65  return failure();
66  argument.attrs = attrs.getDictionary(parser.getContext());
67  }
68  arguments.push_back(argument);
69  return success();
70  });
71 }
72 
73 /// Parse a function result list.
74 ///
75 /// function-result-list ::= function-result-list-parens
76 /// | non-function-type
77 /// function-result-list-parens ::= `(` `)`
78 /// | `(` function-result-list-no-parens `)`
79 /// function-result-list-no-parens ::= function-result (`,` function-result)*
80 /// function-result ::= type attribute-dict?
81 ///
82 static ParseResult
84  SmallVectorImpl<DictionaryAttr> &resultAttrs) {
85  if (failed(parser.parseOptionalLParen())) {
86  // We already know that there is no `(`, so parse a type.
87  // Because there is no `(`, it cannot be a function type.
88  Type ty;
89  if (parser.parseType(ty))
90  return failure();
91  resultTypes.push_back(ty);
92  resultAttrs.emplace_back();
93  return success();
94  }
95 
96  // Special case for an empty set of parens.
97  if (succeeded(parser.parseOptionalRParen()))
98  return success();
99 
100  // Parse individual function results.
101  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
102  resultTypes.emplace_back();
103  resultAttrs.emplace_back();
104  NamedAttrList attrs;
105  if (parser.parseType(resultTypes.back()) ||
106  parser.parseOptionalAttrDict(attrs))
107  return failure();
108  resultAttrs.back() = attrs.getDictionary(parser.getContext());
109  return success();
110  }))
111  return failure();
112 
113  return parser.parseRParen();
114 }
115 
117  OpAsmParser &parser, bool allowVariadic,
118  SmallVectorImpl<OpAsmParser::Argument> &arguments, bool &isVariadic,
119  SmallVectorImpl<Type> &resultTypes,
120  SmallVectorImpl<DictionaryAttr> &resultAttrs) {
121  if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic))
122  return failure();
123  if (succeeded(parser.parseOptionalArrow()))
124  return parseFunctionResultList(parser, resultTypes, resultAttrs);
125  return success();
126 }
127 
129  Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
130  ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
131  StringAttr resAttrsName) {
132  auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
133  return attrs && !attrs.empty();
134  };
135  // Convert the specified array of dictionary attrs (which may have null
136  // entries) to an ArrayAttr of dictionaries.
137  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
139  for (auto &dict : dictAttrs)
140  attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
141  return builder.getArrayAttr(attrs);
142  };
143 
144  // Add the attributes to the function arguments.
145  if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
146  result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
147 
148  // Add the attributes to the function results.
149  if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
150  result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
151 }
152 
154  Builder &builder, OperationState &result,
156  StringAttr argAttrsName, StringAttr resAttrsName) {
158  for (const auto &arg : args)
159  argAttrs.push_back(arg.attrs);
160  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
161  resAttrsName);
162 }
163 
165  OpAsmParser &parser, OperationState &result, bool allowVariadic,
166  StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder,
167  StringAttr argAttrsName, StringAttr resAttrsName) {
169  SmallVector<DictionaryAttr> resultAttrs;
170  SmallVector<Type> resultTypes;
171  auto &builder = parser.getBuilder();
172 
173  // Parse visibility.
174  (void)impl::parseOptionalVisibilityKeyword(parser, result.attributes);
175 
176  // Parse the name as a symbol.
177  StringAttr nameAttr;
178  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
179  result.attributes))
180  return failure();
181 
182  // Parse the function signature.
183  SMLoc signatureLocation = parser.getCurrentLocation();
184  bool isVariadic = false;
185  if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic,
186  resultTypes, resultAttrs))
187  return failure();
188 
189  std::string errorMessage;
190  SmallVector<Type> argTypes;
191  argTypes.reserve(entryArgs.size());
192  for (auto &arg : entryArgs)
193  argTypes.push_back(arg.type);
194  Type type = funcTypeBuilder(builder, argTypes, resultTypes,
195  VariadicFlag(isVariadic), errorMessage);
196  if (!type) {
197  return parser.emitError(signatureLocation)
198  << "failed to construct function type"
199  << (errorMessage.empty() ? "" : ": ") << errorMessage;
200  }
201  result.addAttribute(typeAttrName, TypeAttr::get(type));
202 
203  // If function attributes are present, parse them.
204  NamedAttrList parsedAttributes;
205  SMLoc attributeDictLocation = parser.getCurrentLocation();
206  if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
207  return failure();
208 
209  // Disallow attributes that are inferred from elsewhere in the attribute
210  // dictionary.
211  for (StringRef disallowed :
212  {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
213  typeAttrName.getValue()}) {
214  if (parsedAttributes.get(disallowed))
215  return parser.emitError(attributeDictLocation, "'")
216  << disallowed
217  << "' is an inferred attribute and should not be specified in the "
218  "explicit attribute dictionary";
219  }
220  result.attributes.append(parsedAttributes);
221 
222  // Add the attributes to the function arguments.
223  assert(resultAttrs.size() == resultTypes.size());
224  addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName,
225  resAttrsName);
226 
227  // Parse the optional function body. The printer will not print the body if
228  // its empty, so disallow parsing of empty body in the parser.
229  auto *body = result.addRegion();
230  SMLoc loc = parser.getCurrentLocation();
231  OptionalParseResult parseResult =
232  parser.parseOptionalRegion(*body, entryArgs,
233  /*enableNameShadowing=*/false);
234  if (parseResult.has_value()) {
235  if (failed(*parseResult))
236  return failure();
237  // Function body was parsed, make sure its not empty.
238  if (body->empty())
239  return parser.emitError(loc, "expected non-empty function body");
240  }
241  return success();
242 }
243 
244 /// Print a function result list. The provided `attrs` must either be null, or
245 /// contain a set of DictionaryAttrs of the same arity as `types`.
247  ArrayAttr attrs) {
248  assert(!types.empty() && "Should not be called for empty result list.");
249  assert((!attrs || attrs.size() == types.size()) &&
250  "Invalid number of attributes.");
251 
252  auto &os = p.getStream();
253  bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
254  (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
255  if (needsParens)
256  os << '(';
257  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
258  p.printType(types[i]);
259  if (attrs)
260  p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
261  });
262  if (needsParens)
263  os << ')';
264 }
265 
267  OpAsmPrinter &p, FunctionOpInterface op, ArrayRef<Type> argTypes,
268  bool isVariadic, ArrayRef<Type> resultTypes) {
269  Region &body = op->getRegion(0);
270  bool isExternal = body.empty();
271 
272  p << '(';
273  ArrayAttr argAttrs = op.getArgAttrsAttr();
274  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
275  if (i > 0)
276  p << ", ";
277 
278  if (!isExternal) {
280  if (argAttrs)
281  attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
282  p.printRegionArgument(body.getArgument(i), attrs);
283  } else {
284  p.printType(argTypes[i]);
285  if (argAttrs)
287  llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
288  }
289  }
290 
291  if (isVariadic) {
292  if (!argTypes.empty())
293  p << ", ";
294  p << "...";
295  }
296 
297  p << ')';
298 
299  if (!resultTypes.empty()) {
300  p.getStream() << " -> ";
301  auto resultAttrs = op.getResAttrsAttr();
302  printFunctionResultList(p, resultTypes, resultAttrs);
303  }
304 }
305 
307  OpAsmPrinter &p, Operation *op, ArrayRef<StringRef> elided) {
308  // Print out function attributes, if present.
309  SmallVector<StringRef, 8> ignoredAttrs = {SymbolTable::getSymbolAttrName()};
310  ignoredAttrs.append(elided.begin(), elided.end());
311 
312  p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
313 }
314 
316  OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic,
317  StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName) {
318  // Print the operation and the function name.
319  auto funcName =
320  op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
321  .getValue();
322  p << ' ';
323 
324  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
325  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
326  p << visibility.getValue() << ' ';
327  p.printSymbolName(funcName);
328 
329  ArrayRef<Type> argTypes = op.getArgumentTypes();
330  ArrayRef<Type> resultTypes = op.getResultTypes();
331  printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
333  p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName});
334  // Print the body if this is not an external function.
335  Region &body = op->getRegion(0);
336  if (!body.empty()) {
337  p << ' ';
338  p.printRegion(body, /*printEntryBlockArgs=*/false,
339  /*printBlockTerminators=*/true);
340  }
341 }
static void printFunctionResultList(OpAsmPrinter &p, ArrayRef< Type > types, ArrayAttr attrs)
Print a function result list.
static ParseResult parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic)
static ParseResult parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parse a function result list.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalEllipsis()=0
Parse a `...` token if present.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:120
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument if present.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:545
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_type_range getResultTypes()
Definition: Operation.h:423
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
ParseResult parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, NamedAttrList &attrs)
Parse an optional visibility attribute keyword (i.e., public, private, or nested) without quotes in a...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.