MLIR  16.0.0git
Async.cpp
Go to the documentation of this file.
1 //===- Async.cpp - MLIR Async Operations ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::async;
16 
17 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
18 
// `AsyncDialect::kAllowedToBlockAttrName` is declared `static constexpr` (with
// its initializer) inside the dialect class, which makes it implicitly `inline`
// since C++17. An out-of-line definition is therefore redundant (and
// deprecated), so none is provided here.
20 
/// Registers the Async dialect's operations and types with the dialect.
/// Both lists are produced by TableGen: the GET_OP_LIST / GET_TYPEDEF_LIST
/// sections of the generated .cpp.inc files expand to comma-separated class
/// names that become the template arguments below.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}
31 
32 //===----------------------------------------------------------------------===//
33 // YieldOp
34 //===----------------------------------------------------------------------===//
35 
37  // Get the underlying value types from async values returned from the
38  // parent `async.execute` operation.
39  auto executeOp = (*this)->getParentOfType<ExecuteOp>();
40  auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
41  return result.getType().cast<ValueType>().getValueType();
42  });
43 
44  if (getOperandTypes() != types)
45  return emitOpError("operand types do not match the types returned from "
46  "the parent ExecuteOp");
47 
48  return success();
49 }
50 
52 YieldOp::getMutableSuccessorOperands(Optional<unsigned> index) {
53  return operandsMutable();
54 }
55 
56 //===----------------------------------------------------------------------===//
57 /// ExecuteOp
58 //===----------------------------------------------------------------------===//
59 
60 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
61 
62 OperandRange ExecuteOp::getSuccessorEntryOperands(Optional<unsigned> index) {
63  assert(index && *index == 0 && "invalid region index");
64  return operands();
65 }
66 
67 bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
68  const auto getValueOrTokenType = [](Type type) {
69  if (auto value = type.dyn_cast<ValueType>())
70  return value.getValueType();
71  return type;
72  };
73  return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
74 }
75 
76 void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
79  // The `body` region branch back to the parent operation.
80  if (index) {
81  assert(*index == 0 && "invalid region index");
82  regions.push_back(RegionSuccessor(results()));
83  return;
84  }
85 
86  // Otherwise the successor is the body region.
87  regions.push_back(RegionSuccessor(&body(), body().getArguments()));
88 }
89 
90 void ExecuteOp::build(OpBuilder &builder, OperationState &result,
91  TypeRange resultTypes, ValueRange dependencies,
92  ValueRange operands, BodyBuilderFn bodyBuilder) {
93 
94  result.addOperands(dependencies);
95  result.addOperands(operands);
96 
97  // Add derived `operand_segment_sizes` attribute based on parsed operands.
98  int32_t numDependencies = dependencies.size();
99  int32_t numOperands = operands.size();
100  auto operandSegmentSizes = DenseIntElementsAttr::get(
101  VectorType::get({2}, builder.getIntegerType(32)),
102  {numDependencies, numOperands});
103  result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
104 
105  // First result is always a token, and then `resultTypes` wrapped into
106  // `async.value`.
107  result.addTypes({TokenType::get(result.getContext())});
108  for (Type type : resultTypes)
109  result.addTypes(ValueType::get(type));
110 
111  // Add a body region with block arguments as unwrapped async value operands.
112  Region *bodyRegion = result.addRegion();
113  bodyRegion->push_back(new Block);
114  Block &bodyBlock = bodyRegion->front();
115  for (Value operand : operands) {
116  auto valueType = operand.getType().dyn_cast<ValueType>();
117  bodyBlock.addArgument(valueType ? valueType.getValueType()
118  : operand.getType(),
119  operand.getLoc());
120  }
121 
122  // Create the default terminator if the builder is not provided and if the
123  // expected result is empty. Otherwise, leave this to the caller
124  // because we don't know which values to return from the execute op.
125  if (resultTypes.empty() && !bodyBuilder) {
126  OpBuilder::InsertionGuard guard(builder);
127  builder.setInsertionPointToStart(&bodyBlock);
128  builder.create<async::YieldOp>(result.location, ValueRange());
129  } else if (bodyBuilder) {
130  OpBuilder::InsertionGuard guard(builder);
131  builder.setInsertionPointToStart(&bodyBlock);
132  bodyBuilder(builder, result.location, bodyBlock.getArguments());
133  }
134 }
135 
137  // [%tokens,...]
138  if (!dependencies().empty())
139  p << " [" << dependencies() << "]";
140 
141  // (%value as %unwrapped: !async.value<!arg.type>, ...)
142  if (!operands().empty()) {
143  p << " (";
144  Block *entry = body().empty() ? nullptr : &body().front();
145  llvm::interleaveComma(operands(), p, [&, n = 0](Value operand) mutable {
146  Value argument = entry ? entry->getArgument(n++) : Value();
147  p << operand << " as " << argument << ": " << operand.getType();
148  });
149  p << ")";
150  }
151 
152  // -> (!async.value<!return.type>, ...)
153  p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));
154  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
155  {kOperandSegmentSizesAttr});
156  p << ' ';
157  p.printRegion(body(), /*printEntryBlockArgs=*/false);
158 }
159 
160 ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
161  MLIRContext *ctx = result.getContext();
162 
163  // Sizes of parsed variadic operands, will be updated below after parsing.
164  int32_t numDependencies = 0;
165 
166  auto tokenTy = TokenType::get(ctx);
167 
168  // Parse dependency tokens.
169  if (succeeded(parser.parseOptionalLSquare())) {
171  if (parser.parseOperandList(tokenArgs) ||
172  parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
173  parser.parseRSquare())
174  return failure();
175 
176  numDependencies = tokenArgs.size();
177  }
178 
179  // Parse async value operands (%value as %unwrapped : !async.value<!type>).
182  SmallVector<Type, 4> valueTypes;
183 
184  // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
185  auto parseAsyncValueArg = [&]() -> ParseResult {
186  if (parser.parseOperand(valueArgs.emplace_back()) ||
187  parser.parseKeyword("as") ||
188  parser.parseArgument(unwrappedArgs.emplace_back()) ||
189  parser.parseColonType(valueTypes.emplace_back()))
190  return failure();
191 
192  auto valueTy = valueTypes.back().dyn_cast<ValueType>();
193  unwrappedArgs.back().type = valueTy ? valueTy.getValueType() : Type();
194  return success();
195  };
196 
197  auto argsLoc = parser.getCurrentLocation();
199  parseAsyncValueArg) ||
200  parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
201  return failure();
202 
203  int32_t numOperands = valueArgs.size();
204 
205  // Add derived `operand_segment_sizes` attribute based on parsed operands.
206  auto operandSegmentSizes = DenseIntElementsAttr::get(
207  VectorType::get({2}, parser.getBuilder().getI32Type()),
208  {numDependencies, numOperands});
209  result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
210 
211  // Parse the types of results returned from the async execute op.
212  SmallVector<Type, 4> resultTypes;
213  NamedAttrList attrs;
214  if (parser.parseOptionalArrowTypeList(resultTypes) ||
215  // Async execute first result is always a completion token.
216  parser.addTypeToList(tokenTy, result.types) ||
217  parser.addTypesToList(resultTypes, result.types) ||
218  // Parse operation attributes.
219  parser.parseOptionalAttrDictWithKeyword(attrs))
220  return failure();
221 
222  result.addAttributes(attrs);
223 
224  // Parse asynchronous region.
225  Region *body = result.addRegion();
226  return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
227 }
228 
229 LogicalResult ExecuteOp::verifyRegions() {
230  // Unwrap async.execute value operands types.
231  auto unwrappedTypes = llvm::map_range(operands(), [](Value operand) {
232  return operand.getType().cast<ValueType>().getValueType();
233  });
234 
235  // Verify that unwrapped argument types matches the body region arguments.
236  if (body().getArgumentTypes() != unwrappedTypes)
237  return emitOpError("async body region argument types do not match the "
238  "execute operation arguments types");
239 
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 /// CreateGroupOp
245 //===----------------------------------------------------------------------===//
246 
247 LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
248  PatternRewriter &rewriter) {
249  // Find all `await_all` users of the group.
250  llvm::SmallVector<AwaitAllOp> awaitAllUsers;
251 
252  auto isAwaitAll = [&](Operation *op) -> bool {
253  if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
254  awaitAllUsers.push_back(awaitAll);
255  return true;
256  }
257  return false;
258  };
259 
260  // Check if all users of the group are `await_all` operations.
261  if (!llvm::all_of(op->getUsers(), isAwaitAll))
262  return failure();
263 
264  // If group is only awaited without adding anything to it, we can safely erase
265  // the create operation and all users.
266  for (AwaitAllOp awaitAll : awaitAllUsers)
267  rewriter.eraseOp(awaitAll);
268  rewriter.eraseOp(op);
269 
270  return success();
271 }
272 
273 //===----------------------------------------------------------------------===//
274 /// AwaitOp
275 //===----------------------------------------------------------------------===//
276 
277 void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
278  ArrayRef<NamedAttribute> attrs) {
279  result.addOperands({operand});
280  result.attributes.append(attrs.begin(), attrs.end());
281 
282  // Add unwrapped async.value type to the returned values types.
283  if (auto valueType = operand.getType().dyn_cast<ValueType>())
284  result.addTypes(valueType.getValueType());
285 }
286 
287 static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
288  Type &resultType) {
289  if (parser.parseType(operandType))
290  return failure();
291 
292  // Add unwrapped async.value type to the returned values types.
293  if (auto valueType = operandType.dyn_cast<ValueType>())
294  resultType = valueType.getValueType();
295 
296  return success();
297 }
298 
300  Type operandType, Type resultType) {
301  p << operandType;
302 }
303 
305  Type argType = operand().getType();
306 
307  // Awaiting on a token does not have any results.
308  if (argType.isa<TokenType>() && !getResultTypes().empty())
309  return emitOpError("awaiting on a token must have empty result");
310 
311  // Awaiting on a value unwraps the async value type.
312  if (auto value = argType.dyn_cast<ValueType>()) {
313  if (*getResultType() != value.getValueType())
314  return emitOpError() << "result type " << *getResultType()
315  << " does not match async value type "
316  << value.getValueType();
317  }
318 
319  return success();
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // TableGen'd op method definitions
324 //===----------------------------------------------------------------------===//
325 
326 #define GET_OP_CLASSES
327 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
328 
329 //===----------------------------------------------------------------------===//
330 // TableGen'd type method definitions
331 //===----------------------------------------------------------------------===//
332 
333 #define GET_TYPEDEF_CLASSES
334 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
335 
336 void ValueType::print(AsmPrinter &printer) const {
337  printer << "<";
338  printer.printType(getValueType());
339  printer << '>';
340 }
341 
342 Type ValueType::parse(mlir::AsmParser &parser) {
343  Type ty;
344  if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
345  parser.emitError(parser.getNameLoc(), "failed to parse async value type");
346  return Type();
347  }
348  return ValueType::get(ty);
349 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This is a value defined by a result of an operation.
Definition: Value.h:425
virtual void printType(Type type)
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
void push_back(Block *block)
Definition: Region.h:61
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
static void printAwaitResultType(OpAsmPrinter &p, Operation *op, Type operandType, Type resultType)
Definition: Async.cpp:299
constexpr char kOperandSegmentSizesAttr[]
ExecuteOp.
Definition: Async.cpp:60
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
static constexpr const bool value
SmallVector< Value, 4 > operands
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void addOperands(ValueRange newOperands)
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
U dyn_cast() const
Definition: Types.h:270
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:58
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, Type &resultType)
Definition: Async.cpp:287
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
Parens supporting zero or more operands, or nothing.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
void addTypes(ArrayRef< Type > newTypes)
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:114
This represents an operation in an abstracted form, suitable for use with the builder APIs...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
BlockArgListType getArguments()
Definition: Block.h:76
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
bool empty()
Definition: Block.h:139
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
MLIRContext * getContext() const
Get the context held by this operation state.
NamedAttrList attributes
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:377
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:294
This class represents a successor of a region.
Region * addRegion()
Create a region that should be attached to the operation.
Type getType() const
Return the type of this value.
Definition: Value.h:118
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
virtual ParseResult parseType(Type &result)=0
Parse a type.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:40
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
This base class exposes generic asm printer hooks, usable across the various derived printers...
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
bool isa() const
Definition: Types.h:254
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class helps build Operations.
Definition: Builders.h:192
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter...
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success. ...
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
IntegerType getI32Type()
Definition: Builders.cpp:54
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
U cast() const
Definition: Types.h:278
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
SmallVector< Type, 4 > types
Types of the results of this operation.