14 #include "llvm/ADT/MapVector.h"
15 #include "llvm/ADT/TypeSwitch.h"
20 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
22 constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
24 void AsyncDialect::initialize() {
27 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
30 #define GET_TYPEDEF_LIST
31 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
42 auto executeOp = (*this)->getParentOfType<ExecuteOp>();
44 llvm::map_range(executeOp.getBodyResults(), [](
const OpResult &result) {
45 return result.getType().cast<ValueType>().getValueType();
48 if (getOperandTypes() != types)
49 return emitOpError(
"operand types do not match the types returned from "
50 "the parent ExecuteOp");
56 YieldOp::getMutableSuccessorOperands(std::optional<unsigned> index) {
57 return getOperandsMutable();
67 ExecuteOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
68 assert(index && *index == 0 &&
"invalid region index");
69 return getBodyOperands();
72 bool ExecuteOp::areTypesCompatible(
Type lhs,
Type rhs) {
73 const auto getValueOrTokenType = [](
Type type) {
74 if (
auto value = type.dyn_cast<ValueType>())
75 return value.getValueType();
78 return getValueOrTokenType(lhs) == getValueOrTokenType(rhs);
81 void ExecuteOp::getSuccessorRegions(std::optional<unsigned> index,
86 assert(*index == 0 &&
"invalid region index");
98 ValueRange operands, BodyBuilderFn bodyBuilder) {
104 int32_t numDependencies = dependencies.size();
105 int32_t numOperands = operands.size();
106 auto operandSegmentSizes =
113 for (
Type type : resultTypes)
114 result.
addTypes(ValueType::get(type));
120 for (
Value operand : operands) {
121 auto valueType = operand.getType().dyn_cast<ValueType>();
122 bodyBlock.
addArgument(valueType ? valueType.getValueType()
130 if (resultTypes.empty() && !bodyBuilder) {
134 }
else if (bodyBuilder) {
143 if (!getDependencies().empty())
144 p <<
" [" << getDependencies() <<
"]";
147 if (!getBodyOperands().empty()) {
149 Block *entry = getBodyRegion().
empty() ? nullptr : &getBodyRegion().front();
150 llvm::interleaveComma(
151 getBodyOperands(), p, [&, n = 0](
Value operand)
mutable {
153 p << operand <<
" as " << argument <<
": " << operand.
getType();
161 {kOperandSegmentSizesAttr});
170 int32_t numDependencies = 0;
172 auto tokenTy = TokenType::get(ctx);
182 numDependencies = tokenArgs.size();
198 auto valueTy = valueTypes.back().dyn_cast<ValueType>();
199 unwrappedArgs.back().type = valueTy ? valueTy.getValueType() :
Type();
205 parseAsyncValueArg) ||
209 int32_t numOperands = valueArgs.size();
212 auto operandSegmentSizes =
236 auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](
Value operand) {
237 return operand.
getType().cast<ValueType>().getValueType();
241 if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
242 return emitOpError(
"async body region argument types do not match the "
243 "execute operation arguments types");
257 auto isAwaitAll = [&](
Operation *op) ->
bool {
258 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
259 awaitAllUsers.push_back(awaitAll);
266 if (!llvm::all_of(op->getUsers(), isAwaitAll))
271 for (AwaitAllOp awaitAll : awaitAllUsers)
289 result.
addTypes(valueType.getValueType());
298 if (
auto valueType = operandType.
dyn_cast<ValueType>())
299 resultType = valueType.getValueType();
305 Type operandType,
Type resultType) {
310 Type argType = getOperand().getType();
313 if (argType.
isa<TokenType>() && !getResultTypes().empty())
314 return emitOpError(
"awaiting on a token must have empty result");
317 if (
auto value = argType.
dyn_cast<ValueType>()) {
318 if (*getResultType() != value.getValueType())
319 return emitOpError() <<
"result type " << *getResultType()
320 <<
" does not match async value type "
321 << value.getValueType();
336 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
341 if (argAttrs.empty())
343 assert(type.getNumInputs() == argAttrs.size());
345 builder, state, argAttrs, std::nullopt,
346 getArgAttrsAttrName(state.
name), getResAttrsAttrName(state.
name));
356 parser, result,
false,
357 getFunctionTypeAttrName(result.
name), buildFuncType,
358 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
363 p, *
this,
false, getFunctionTypeAttrName(),
364 getArgAttrsAttrName(), getResAttrsAttrName());
370 auto resultTypes = getResultTypes();
371 if (resultTypes.empty())
373 <<
"result is expected to be at least of size 1, but got "
374 << resultTypes.size();
376 for (
unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
377 auto type = resultTypes[i];
378 if (!type.isa<TokenType>() && !type.isa<ValueType>())
379 return emitOpError() <<
"result type must be async value type or async "
380 "token type, but got "
383 if (type.isa<TokenType>() && i != 0) {
385 <<
" results' (optional) async token type is expected "
386 "to appear as the 1st return value, but got "
402 return emitOpError(
"requires a 'callee' symbol reference attribute");
405 return emitOpError() <<
"'" << fnAttr.getValue()
406 <<
"' does not reference a valid async function";
409 auto fnType = fn.getFunctionType();
410 if (fnType.getNumInputs() != getNumOperands())
411 return emitOpError(
"incorrect number of operands for callee");
413 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
414 if (getOperand(i).getType() != fnType.getInput(i))
415 return emitOpError(
"operand type mismatch: expected operand type ")
416 << fnType.getInput(i) <<
", but provided "
417 << getOperand(i).getType() <<
" for operand number " << i;
419 if (fnType.getNumResults() != getNumResults())
420 return emitOpError(
"incorrect number of results for callee");
422 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
423 if (getResult(i).getType() != fnType.getResult(i)) {
424 auto diag = emitOpError(
"result type mismatch at index ") << i;
425 diag.attachNote() <<
" op result types: " << getResultTypes();
426 diag.attachNote() <<
"function result types: " << fnType.getResults();
433 FunctionType CallOp::getCalleeType() {
434 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
442 auto funcOp = (*this)->getParentOfType<FuncOp>();
444 ? funcOp.getResultTypes().drop_front()
445 : funcOp.getResultTypes();
448 auto types = llvm::map_range(resultTypes, [](
const Type &result) {
449 return result.
cast<ValueType>().getValueType();
452 if (getOperandTypes() != types)
453 return emitOpError(
"operand types do not match the types returned from "
454 "the parent FuncOp");
463 #define GET_OP_CLASSES
464 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
470 #define GET_TYPEDEF_CLASSES
471 #include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
479 Type ValueType::parse(mlir::AsmParser &parser) {
481 if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
482 parser.emitError(parser.getNameLoc(), "failed to parse async value type");
485 return ValueType::get(ty);
static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, Type &resultType)
constexpr char kOperandSegmentSizesAttr[]
ExecuteOp.
static void printAwaitResultType(OpAsmPrinter &p, Operation *op, Type operandType, Type resultType)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
@ OptionalParen
Parens supporting zero or more operands, or nothing.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
StringAttr getStringAttr(const Twine &bytes)
A symbol reference with a reference path containing a single element.
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
MLIRContext * getContext() const
Get the context held by this operation state.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.