MLIR  21.0.0git
CallInterfaces.cpp
Go to the documentation of this file.
1 //===- CallInterfaces.cpp - Call Interfaces -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // Argument and result attributes utilities
16 //===----------------------------------------------------------------------===//
17 
18 static ParseResult
21  // Parse individual function results.
22  return parser.parseCommaSeparatedList([&]() -> ParseResult {
23  types.emplace_back();
24  attrs.emplace_back();
25  NamedAttrList attrList;
26  if (parser.parseType(types.back()) ||
27  parser.parseOptionalAttrDict(attrList))
28  return failure();
29  attrs.back() = attrList.getDictionary(parser.getContext());
30  return success();
31  });
32 }
33 
35  OpAsmParser &parser, SmallVectorImpl<Type> &resultTypes,
36  SmallVectorImpl<DictionaryAttr> &resultAttrs) {
37  if (failed(parser.parseOptionalLParen())) {
38  // We already know that there is no `(`, so parse a type.
39  // Because there is no `(`, it cannot be a function type.
40  Type ty;
41  if (parser.parseType(ty))
42  return failure();
43  resultTypes.push_back(ty);
44  resultAttrs.emplace_back();
45  return success();
46  }
47 
48  // Special case for an empty set of parens.
49  if (succeeded(parser.parseOptionalRParen()))
50  return success();
51  if (parseTypeAndAttrList(parser, resultTypes, resultAttrs))
52  return failure();
53  return parser.parseRParen();
54 }
55 
57  OpAsmParser &parser, SmallVectorImpl<Type> &argTypes,
59  SmallVectorImpl<Type> &resultTypes,
60  SmallVectorImpl<DictionaryAttr> &resultAttrs, bool mustParseEmptyResult) {
61  // Parse arguments.
62  if (parser.parseLParen())
63  return failure();
64  if (failed(parser.parseOptionalRParen())) {
65  if (parseTypeAndAttrList(parser, argTypes, argAttrs))
66  return failure();
67  if (parser.parseRParen())
68  return failure();
69  }
70  // Parse results.
71  if (succeeded(parser.parseOptionalArrow()))
72  return call_interface_impl::parseFunctionResultList(parser, resultTypes,
73  resultAttrs);
74  if (mustParseEmptyResult)
75  return failure();
76  return success();
77 }
78 
79 /// Print a function result list. The provided `attrs` must either be null, or
80 /// contain a set of DictionaryAttrs of the same arity as `types`.
82  ArrayAttr attrs) {
83  assert(!types.empty() && "Should not be called for empty result list.");
84  assert((!attrs || attrs.size() == types.size()) &&
85  "Invalid number of attributes.");
86 
87  auto &os = p.getStream();
88  bool needsParens = types.size() > 1 || llvm::isa<FunctionType>(types[0]) ||
89  (attrs && !llvm::cast<DictionaryAttr>(attrs[0]).empty());
90  if (needsParens)
91  os << '(';
92  llvm::interleaveComma(llvm::seq<size_t>(0, types.size()), os, [&](size_t i) {
93  p.printType(types[i]);
94  if (attrs)
95  p.printOptionalAttrDict(llvm::cast<DictionaryAttr>(attrs[i]).getValue());
96  });
97  if (needsParens)
98  os << ')';
99 }
100 
102  OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic,
103  TypeRange resultTypes, ArrayAttr resultAttrs, Region *body,
104  bool printEmptyResult) {
105  bool isExternal = !body || body->empty();
106  if (!isExternal && !isVariadic && !argAttrs && !resultAttrs &&
107  printEmptyResult) {
108  p.printFunctionalType(argTypes, resultTypes);
109  return;
110  }
111 
112  p << '(';
113  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
114  if (i > 0)
115  p << ", ";
116 
117  if (!isExternal) {
119  if (argAttrs)
120  attrs = llvm::cast<DictionaryAttr>(argAttrs[i]).getValue();
121  p.printRegionArgument(body->getArgument(i), attrs);
122  } else {
123  p.printType(argTypes[i]);
124  if (argAttrs)
126  llvm::cast<DictionaryAttr>(argAttrs[i]).getValue());
127  }
128  }
129 
130  if (isVariadic) {
131  if (!argTypes.empty())
132  p << ", ";
133  p << "...";
134  }
135 
136  p << ')';
137 
138  if (!resultTypes.empty()) {
139  p << " -> ";
140  printFunctionResultList(p, resultTypes, resultAttrs);
141  } else if (printEmptyResult) {
142  p << " -> ()";
143  }
144 }
145 
147  Builder &builder, OperationState &result, ArrayRef<DictionaryAttr> argAttrs,
148  ArrayRef<DictionaryAttr> resultAttrs, StringAttr argAttrsName,
149  StringAttr resAttrsName) {
150  auto nonEmptyAttrsFn = [](DictionaryAttr attrs) {
151  return attrs && !attrs.empty();
152  };
153  // Convert the specified array of dictionary attrs (which may have null
154  // entries) to an ArrayAttr of dictionaries.
155  auto getArrayAttr = [&](ArrayRef<DictionaryAttr> dictAttrs) {
157  for (auto &dict : dictAttrs)
158  attrs.push_back(dict ? dict : builder.getDictionaryAttr({}));
159  return builder.getArrayAttr(attrs);
160  };
161 
162  // Add the attributes to the operation arguments.
163  if (llvm::any_of(argAttrs, nonEmptyAttrsFn))
164  result.addAttribute(argAttrsName, getArrayAttr(argAttrs));
165 
166  // Add the attributes to the operation results.
167  if (llvm::any_of(resultAttrs, nonEmptyAttrsFn))
168  result.addAttribute(resAttrsName, getArrayAttr(resultAttrs));
169 }
170 
172  Builder &builder, OperationState &result,
174  StringAttr argAttrsName, StringAttr resAttrsName) {
176  for (const auto &arg : args)
177  argAttrs.push_back(arg.attrs);
178  addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName,
179  resAttrsName);
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // CallOpInterface
184 //===----------------------------------------------------------------------===//
185 
186 Operation *
188  SymbolTableCollection *symbolTable) {
189  CallInterfaceCallable callable = call.getCallableForCallee();
190  if (auto symbolVal = dyn_cast<Value>(callable))
191  return symbolVal.getDefiningOp();
192 
193  // If the callable isn't a value, lookup the symbol reference.
194  auto symbolRef = cast<SymbolRefAttr>(callable);
195  if (symbolTable)
196  return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef);
197  return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef);
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // CallInterfaces
202 //===----------------------------------------------------------------------===//
203 
204 #include "mlir/Interfaces/CallInterfaces.cpp.inc"
static void printFunctionResultList(OpAsmPrinter &p, TypeRange types, ArrayAttr attrs)
Print a function result list.
static ParseResult parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl< Type > &types, SmallVectorImpl< DictionaryAttr > &attrs)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
virtual void printType(Type type)
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition: Builders.cpp:100
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic, TypeRange resultTypes, ArrayAttr resultAttrs, Region *body=nullptr, bool printEmptyResult=true)
Print a function signature for a call or callable operation.
Operation * resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable=nullptr)
Resolve the callable operation for given callee to a CallableOpInterface, or nullptr if a valid calla...
ParseResult parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parse a function or call result list.
ParseResult parseFunctionSignature(OpAsmParser &parser, SmallVectorImpl< Type > &argTypes, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs, bool mustParseEmptyResult=true)
Parses a function signature using parser.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
Include the generated interface declarations.
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.