MLIR  14.0.0git
Async.cpp
Go to the documentation of this file.
1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18 
// Out-of-line definition of the in-class `static constexpr` data member
// declared on AsyncDialect (it has no initializer here, so the in-class
// declaration must provide it).
// NOTE(review): presumably kept so the member can be ODR-used under pre-C++17
// rules, where such members are not implicitly inline; redundant but harmless
// under C++17.
19 constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
20 
// Registers every Async dialect operation and type with the dialect.
// The operation list and the type list are expanded from the TableGen-generated
// .inc files, selected by the GET_OP_LIST / GET_TYPEDEF_LIST macros.
21 void AsyncDialect::initialize() {
22  addOperations<
23 #define GET_OP_LIST
24 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
25  >();
26  addTypes<
27 #define GET_TYPEDEF_LIST
28 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
29  >();
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // YieldOp
34 //===----------------------------------------------------------------------===//
35 
36 static LogicalResult verify(YieldOp op) {
37  // Get the underlying value types from async values returned from the
38  // parent `async.execute` operation.
39  auto executeOp = op->getParentOfType<ExecuteOp>();
40  auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
41  return result.getType().cast<ValueType>().getValueType();
42  });
43 
44  if (op.getOperandTypes() != types)
45  return op.emitOpError("operand types do not match the types returned from "
46  "the parent ExecuteOp");
47 
48  return success();
49 }
50 
52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
  // Region-branch terminator hook: every operand of the yield is forwarded as
  // a successor operand. The yield is only ever queried without a region
  // index (asserted below), i.e. for the branch back to the parent op.
53  assert(!index.hasValue());
54  return operandsMutable();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // ExecuteOp
59 //===----------------------------------------------------------------------===//
60 
// Name of the derived attribute recording the sizes of ExecuteOp's two
// variadic operand groups (dependencies, operands). Both ExecuteOp::build and
// the custom ExecuteOp parser attach it.
61 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
62 
63 OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) {
64  assert(index == 0 && "invalid region index");
65  return operands();
66 }
67 
// Describes region control flow for `async.execute`:
//  - from the `body` region (index 0), control returns to the parent op and
//    the values land in the op's `results()` group;
//  - from the parent op (no index), control enters `body`, whose block
//    arguments receive the entry values.
68 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
71  // The `body` region branch back to the parent operation.
72  if (index.hasValue()) {
73  assert(*index == 0 && "invalid region index");
74  regions.push_back(RegionSuccessor(results()));
75  return;
76  }
77 
78  // Otherwise the successor is the body region.
79  regions.push_back(RegionSuccessor(&body(), body().getArguments()));
80 }
81 
82 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
83  TypeRange resultTypes, ValueRange dependencies,
84  ValueRange operands, BodyBuilderFn bodyBuilder) {
85 
86  result.addOperands(dependencies);
87  result.addOperands(operands);
88 
89  // Add derived `operand_segment_sizes` attribute based on parsed operands.
90  int32_t numDependencies = dependencies.size();
91  int32_t numOperands = operands.size();
92  auto operandSegmentSizes = DenseIntElementsAttr::get(
93  VectorType::get({2}, builder.getIntegerType(32)),
94  {numDependencies, numOperands});
95  result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
96 
97  // First result is always a token, and then `resultTypes` wrapped into
98  // `async.value`.
99  result.addTypes({TokenType::get(result.getContext())});
100  for (Type type : resultTypes)
101  result.addTypes(ValueType::get(type));
102 
103  // Add a body region with block arguments as unwrapped async value operands.
104  Region *bodyRegion = result.addRegion();
105  bodyRegion->push_back(new Block);
106  Block &bodyBlock = bodyRegion->front();
107  for (Value operand : operands) {
108  auto valueType = operand.getType().dyn_cast<ValueType>();
109  bodyBlock.addArgument(valueType ? valueType.getValueType()
110  : operand.getType(),
111  operand.getLoc());
112  }
113 
114  // Create the default terminator if the builder is not provided and if the
115  // expected result is empty. Otherwise, leave this to the caller
116  // because we don't know which values to return from the execute op.
117  if (resultTypes.empty() && !bodyBuilder) {
118  OpBuilder::InsertionGuard guard(builder);
119  builder.setInsertionPointToStart(&bodyBlock);
120  builder.create<async::YieldOp>(result.location, ValueRange());
121  } else if (bodyBuilder) {
122  OpBuilder::InsertionGuard guard(builder);
123  builder.setInsertionPointToStart(&bodyBlock);
124  bodyBuilder(builder, result.location, bodyBlock.getArguments());
125  }
126 }
127 
// Custom printer for `async.execute`. Emits, in order: the optional
// `[%tokens]` dependency list, the optional
// `(%value as %unwrapped: <operand type>, ...)` list, an optional
// `-> (types)` list, the optional `attributes {...}` dictionary, and the body
// region.
128 static void print(OpAsmPrinter &p, ExecuteOp op) {
129  // [%tokens,...]
130  if (!op.dependencies().empty())
131  p << " [" << op.dependencies() << "]";
132 
133  // (%value as %unwrapped: !async.value<!arg.type>, ...)
  // Entry-block arguments are printed inline here (paired positionally with
  // the operands), so the region below is printed without them. If the body
  // has no block, a null Value is streamed in place of each argument.
134  if (!op.operands().empty()) {
135  p << " (";
136  Block *entry = op.body().empty() ? nullptr : &op.body().front();
137  llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
138  Value argument = entry ? entry->getArgument(n++) : Value();
139  p << operand << " as " << argument << ": " << operand.getType();
140  });
141  p << ")";
142  }
143 
144  // -> (!async.value<!return.type>, ...)
  // drop_begin skips result 0, the completion token, which is implicit in the
  // custom syntax.
145  p.printOptionalArrowTypeList(llvm::drop_begin(op.getResultTypes()));
146  p.printOptionalAttrDictWithKeyword(op->getAttrs(),
148  p << ' ';
149  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
150 }
151 
  // Custom parser for `async.execute`. It is the inverse of the printer and
  // parses, in order:
  //   - an optional `[%tokens]` list, resolved as `!async.token` operands;
  //   - the `%value as %unwrapped : !async.value<T>` operand list;
  //   - an optional `-> (types)` list;
  //   - an optional `attributes {...}` dictionary;
  //   - the body region, whose entry arguments are the `%unwrapped` names.
153  MLIRContext *ctx = result.getContext();
154 
155  // Sizes of parsed variadic operands, will be updated below after parsing.
156  int32_t numDependencies = 0;
157 
158  auto tokenTy = TokenType::get(ctx);
159 
160  // Parse dependency tokens.
161  if (succeeded(parser.parseOptionalLSquare())) {
163  if (parser.parseOperandList(tokenArgs) ||
164  parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
165  parser.parseRSquare())
166  return failure();
167 
168  numDependencies = tokenArgs.size();
169  }
170 
171  // Parse async value operands (%value as %unwrapped : !async.value<!type>).
174  SmallVector<Type, 4> valueTypes;
175  SmallVector<Type, 4> unwrappedTypes;
176 
177  // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
178  auto parseAsyncValueArg = [&]() -> ParseResult {
179  if (parser.parseOperand(valueArgs.emplace_back()) ||
180  parser.parseKeyword("as") ||
181  parser.parseOperand(unwrappedArgs.emplace_back()) ||
182  parser.parseColonType(valueTypes.emplace_back()))
183  return failure();
184 
  // The block-argument type is the payload of the `!async.value`.
  // NOTE(review): for any other operand type a null Type is recorded here.
  // Presumably it is rejected later, but confirm that parseRegion never
  // receives a null argument type.
185  auto valueTy = valueTypes.back().dyn_cast<ValueType>();
186  unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
187 
188  return success();
189  };
190 
191  auto argsLoc = parser.getCurrentLocation();
193  parseAsyncValueArg) ||
194  parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
195  return failure();
196 
197  int32_t numOperands = valueArgs.size();
198 
199  // Add derived `operand_segment_sizes` attribute based on parsed operands.
200  auto operandSegmentSizes = DenseIntElementsAttr::get(
201  VectorType::get({2}, parser.getBuilder().getI32Type()),
202  {numDependencies, numOperands});
203  result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
204 
205  // Parse the types of results returned from the async execute op.
  // These are the full (already `!async.value`-wrapped) result types. They
  // are added after the token without re-wrapping, matching the printer,
  // which prints drop_begin(getResultTypes()).
206  SmallVector<Type, 4> resultTypes;
207  if (parser.parseOptionalArrowTypeList(resultTypes))
208  return failure();
209 
210  // Async execute first result is always a completion token.
211  parser.addTypeToList(tokenTy, result.types);
212  parser.addTypesToList(resultTypes, result.types);
213 
214  // Parse operation attributes.
215  NamedAttrList attrs;
216  if (parser.parseOptionalAttrDictWithKeyword(attrs))
217  return failure();
218  result.addAttributes(attrs);
219 
220  // Parse asynchronous region.
221  Region *body = result.addRegion();
222  if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
223  /*argTypes=*/{unwrappedTypes},
224  /*argLocations=*/{},
225  /*enableNameShadowing=*/false))
226  return failure();
227 
228  return success();
229 }
230 
231 static LogicalResult verify(ExecuteOp op) {
232  // Unwrap async.execute value operands types.
233  auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
234  return operand.getType().cast<ValueType>().getValueType();
235  });
236 
237  // Verify that unwrapped argument types matches the body region arguments.
238  if (op.body().getArgumentTypes() != unwrappedTypes)
239  return op.emitOpError("async body region argument types do not match the "
240  "execute operation arguments types");
241 
242  return success();
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // CreateGroupOp
247 //===----------------------------------------------------------------------===//
248 
249 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
250  PatternRewriter &rewriter) {
251  // Find all `await_all` users of the group.
252  llvm::SmallVector<AwaitAllOp> awaitAllUsers;
253 
254  auto isAwaitAll = [&](Operation *op) -> bool {
255  if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
256  awaitAllUsers.push_back(awaitAll);
257  return true;
258  }
259  return false;
260  };
261 
262  // Check if all users of the group are `await_all` operations.
263  if (!llvm::all_of(op->getUsers(), isAwaitAll))
264  return failure();
265 
266  // If group is only awaited without adding anything to it, we can safely erase
267  // the create operation and all users.
268  for (AwaitAllOp awaitAll : awaitAllUsers)
269  rewriter.eraseOp(awaitAll);
270  rewriter.eraseOp(op);
271 
272  return success();
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // AwaitOp
277 //===----------------------------------------------------------------------===//
278 
279 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
280  ArrayRef<NamedAttribute> attrs) {
281  result.addOperands({operand});
282  result.attributes.append(attrs.begin(), attrs.end());
283 
284  // Add unwrapped async.value type to the returned values types.
285  if (auto valueType = operand.getType().dyn_cast<ValueType>())
286  result.addTypes(valueType.getValueType());
287 }
288 
290  Type &resultType) {
  // Custom-directive parser for `async.await`: only the operand type appears
  // in the textual form. When it is `!async.value<T>`, `resultType` is set to
  // `T`; otherwise `resultType` is left untouched (no result).
291  if (parser.parseType(operandType))
292  return failure();
293 
294  // Add unwrapped async.value type to the returned values types.
295  if (auto valueType = operandType.dyn_cast<ValueType>())
296  resultType = valueType.getValueType();
297 
298  return success();
299 }
300 
302  Type operandType, Type resultType) {
  // Custom-directive printer for `async.await`: prints only the operand type.
  // The result type is not printed because the matching parser derives it
  // from the operand type; `op` and `resultType` are unused.
303  p << operandType;
304 }
305 
306 static LogicalResult verify(AwaitOp op) {
307  Type argType = op.operand().getType();
308 
309  // Awaiting on a token does not have any results.
310  if (argType.isa<TokenType>() && !op.getResultTypes().empty())
311  return op.emitOpError("awaiting on a token must have empty result");
312 
313  // Awaiting on a value unwraps the async value type.
314  if (auto value = argType.dyn_cast<ValueType>()) {
315  if (*op.getResultType() != value.getValueType())
316  return op.emitOpError()
317  << "result type " << *op.getResultType()
318  << " does not match async value type " << value.getValueType();
319  }
320 
321  return success();
322 }
323 
324 //===----------------------------------------------------------------------===//
325 // TableGen'd op method definitions
326 //===----------------------------------------------------------------------===//
327 
328 #define GET_OP_CLASSES
329 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
330 
331 //===----------------------------------------------------------------------===//
332 // TableGen'd type method definitions
333 //===----------------------------------------------------------------------===//
334 
335 #define GET_TYPEDEF_CLASSES
336 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
337 
338 void ValueType::print(AsmPrinter &printer) const {
339  printer << "<";
340  printer.printType(getValueType());
341  printer << '>';
342 }
343 
344 Type ValueType::parse(mlir::AsmParser &parser) {
345  Type ty;
346  if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
347  parser.emitError(parser.getNameLoc(), "failed to parse async value type");
348  return Type();
349  }
350  return ValueType::get(ty);
351 }
virtual ParseResult parseOperand(OperandType &result)=0
Parse a single operand.
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
ParseResult resolveOperands(ArrayRef< OperandType > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This is a value defined by a result of an operation.
Definition: Value.h:423
virtual void printType(Type type)
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
LogicalResult verify(Operation *op)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:353
void push_back(Block *block)
Definition: Region.h:61
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static void printAwaitResultType(OpAsmPrinter &p, Operation *op, Type operandType, Type resultType)
Definition: Async.cpp:301
constexpr char kOperandSegmentSizesAttr[]
Name of the derived attribute holding the sizes of ExecuteOp's variadic operand segments.
Definition: Async.cpp:61
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
static constexpr const bool value
SmallVector< Value, 4 > operands
virtual ParseResult parseOperandList(SmallVectorImpl< OperandType > &result, int requiredOperandCount=-1, Delimiter delimiter=Delimiter::None)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual ParseResult parseRegion(Region &region, ArrayRef< OperandType > arguments={}, ArrayRef< Type > argTypes={}, ArrayRef< Location > argLocations={}, bool enableNameShadowing=false)=0
Parses a region.
void addOperands(ValueRange newOperands)
virtual llvm::SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
U dyn_cast() const
Definition: Types.h:244
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, Type &resultType)
Definition: Async.cpp:289
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Parens supporting zero or more operands, or nothing.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
void addTypes(ArrayRef< Type > newTypes)
ParseResult parseKeyword(StringRef keyword, const Twine &msg="")
Parse a given keyword.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class provides a mutable adaptor for a range of operands.
This represents an operation in an abstracted form, suitable for use with the builder APIs...
static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result)
Definition: Async.cpp:152
BlockArgListType getArguments()
Definition: Block.h:76
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
bool empty()
Definition: Block.h:139
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
MLIRContext * getContext() const
Get the context held by this operation state.
NamedAttrList attributes
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:362
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:279
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
Type getType() const
Return the type of this value.
Definition: Value.h:117
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
virtual ParseResult parseType(Type &result)=0
Parse a type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
This class implements the operand iterators for the Operation class.
This base class exposes generic asm printer hooks, usable across the various derived printers...
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
bool isa() const
Definition: Types.h:234
static void print(OpAsmPrinter &p, ExecuteOp op)
Definition: Async.cpp:128
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
This class represents success/failure for operation parsing.
Definition: OpDefinition.h:36
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success. ...
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
IntegerType getI32Type()
Definition: Builders.cpp:54
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.