13#include "llvm/ADT/TypeSwitch.h"
18#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
20void AsyncDialect::initialize() {
23#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
26#define GET_TYPEDEF_LIST
27#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
39 "invalid region index");
40 return getBodyOperands();
44 const auto getValueOrTokenType = [](
Type type) {
45 if (
auto value = llvm::dyn_cast<ValueType>(type))
46 return value.getValueType();
49 return getValueOrTokenType(
lhs) == getValueOrTokenType(
rhs);
73 ValueRange operands, BodyBuilderFn bodyBuilder) {
75 result.addOperands(dependencies);
76 result.addOperands(operands);
79 int32_t numDependencies = dependencies.size();
80 int32_t numOperands = operands.size();
81 auto operandSegmentSizes =
88 for (
Type type : resultTypes)
89 result.addTypes(ValueType::get(type));
94 for (
Value operand : operands) {
95 auto valueType = llvm::dyn_cast<ValueType>(operand.getType());
96 bodyBlock->
addArgument(valueType ? valueType.getValueType()
104 if (resultTypes.empty() && !bodyBuilder) {
106 }
else if (bodyBuilder) {
113 if (!getDependencies().empty())
114 p <<
" [" << getDependencies() <<
"]";
117 if (!getBodyOperands().empty()) {
120 llvm::interleaveComma(
121 getBodyOperands(), p, [&, n = 0](
Value operand)
mutable {
123 p << operand <<
" as " << argument <<
": " << operand.
getType();
131 {kOperandSegmentSizesAttr});
140 int32_t numDependencies = 0;
142 auto tokenTy = TokenType::get(ctx);
152 numDependencies = tokenArgs.size();
161 auto parseAsyncValueArg = [&]() -> ParseResult {
168 auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back());
169 unwrappedArgs.back().type = valueTy ? valueTy.getValueType() :
Type();
175 parseAsyncValueArg) ||
179 int32_t numOperands = valueArgs.size();
182 auto operandSegmentSizes =
197 result.addAttributes(attrs);
204LogicalResult ExecuteOp::verifyRegions() {
206 auto unwrappedTypes = llvm::map_range(getBodyOperands(), [](
Value operand) {
207 return llvm::cast<ValueType>(operand.
getType()).getValueType();
211 if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
212 return emitOpError(
"async body region argument types do not match the "
213 "execute operation arguments types");
222LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
227 auto isAwaitAll = [&](
Operation *op) ->
bool {
228 if (AwaitAllOp awaitAll = dyn_cast<AwaitAllOp>(op)) {
229 awaitAllUsers.push_back(awaitAll);
236 if (!llvm::all_of(op->getUsers(), isAwaitAll))
241 for (AwaitAllOp awaitAll : awaitAllUsers)
254 result.addOperands({operand});
255 result.attributes.append(attrs.begin(), attrs.end());
258 if (
auto valueType = llvm::dyn_cast<ValueType>(operand.
getType()))
259 result.addTypes(valueType.getValueType());
268 if (
auto valueType = llvm::dyn_cast<ValueType>(operandType))
269 resultType = valueType.getValueType();
275 Type operandType,
Type resultType) {
279LogicalResult AwaitOp::verify() {
280 Type argType = getOperand().getType();
283 if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
284 return emitOpError(
"awaiting on a token must have empty result");
287 if (
auto value = llvm::dyn_cast<ValueType>(argType)) {
288 if (*getResultType() != value.getValueType())
289 return emitOpError() <<
"result type " << *getResultType()
290 <<
" does not match async value type "
291 << value.getValueType();
306 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
311 if (argAttrs.empty())
313 assert(type.getNumInputs() == argAttrs.size());
315 builder, state, argAttrs, {},
316 getArgAttrsAttrName(state.
name), getResAttrsAttrName(state.
name));
327 getFunctionTypeAttrName(
result.name), buildFuncType,
328 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
333 p, *
this,
false, getFunctionTypeAttrName(),
334 getArgAttrsAttrName(), getResAttrsAttrName());
339LogicalResult FuncOp::verify() {
340 auto resultTypes = getResultTypes();
341 if (resultTypes.empty())
343 <<
"result is expected to be at least of size 1, but got "
344 << resultTypes.size();
346 for (
unsigned i = 0, e = resultTypes.size(); i != e; ++i) {
347 auto type = resultTypes[i];
348 if (!llvm::isa<TokenType>(type) && !llvm::isa<ValueType>(type))
349 return emitOpError() <<
"result type must be async value type or async "
350 "token type, but got "
353 if (llvm::isa<TokenType>(type) && i != 0) {
355 <<
" results' (optional) async token type is expected "
356 "to appear as the 1st return value, but got "
372 return emitOpError(
"requires a 'callee' symbol reference attribute");
376 <<
"' does not reference a valid async function";
379 auto fnType = fn.getFunctionType();
380 if (fnType.getNumInputs() != getNumOperands())
381 return emitOpError(
"incorrect number of operands for callee");
383 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
384 if (getOperand(i).
getType() != fnType.getInput(i))
385 return emitOpError(
"operand type mismatch: expected operand type ")
386 << fnType.getInput(i) <<
", but provided "
387 << getOperand(i).getType() <<
" for operand number " << i;
389 if (fnType.getNumResults() != getNumResults())
390 return emitOpError(
"incorrect number of results for callee");
392 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
393 if (getResult(i).
getType() != fnType.getResult(i)) {
395 diag.attachNote() <<
" op result types: " << getResultTypes();
396 diag.attachNote() <<
"function result types: " << fnType.getResults();
403FunctionType CallOp::getCalleeType() {
404 return FunctionType::get(
getContext(), getOperandTypes(), getResultTypes());
411LogicalResult ReturnOp::verify() {
412 auto funcOp = (*this)->getParentOfType<FuncOp>();
414 ? funcOp.getResultTypes().drop_front()
415 : funcOp.getResultTypes();
418 auto types = llvm::map_range(resultTypes, [](
const Type &
result) {
419 return llvm::cast<ValueType>(
result).getValueType();
422 if (getOperandTypes() != types)
423 return emitOpError(
"operand types do not match the types returned from "
424 "the parent FuncOp");
433#define GET_OP_CLASSES
434#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
440#define GET_TYPEDEF_CLASSES
441#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
443void ValueType::print(
AsmPrinter &printer)
const {
455 return ValueType::get(ty);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, Type &resultType)
constexpr char kOperandSegmentSizesAttr[]
ExecuteOp.
static void printAwaitResultType(OpAsmPrinter &p, Operation *op, Type operandType, Type resultType)
static std::string diag(const llvm::Value &value)
static Type getValueType(Attribute attr)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ OptionalParen
Parens supporting zero or more operands, or nothing.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
ParseResult addTypesToList(ArrayRef< Type > types, SmallVectorImpl< Type > &result)
Add the specified types to the end of the specified type list and return success.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printType(Type type)
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
This class is a general helper class for creating context-global objects like types,...
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
StringAttr getStringAttr(const Twine &bytes)
A symbol reference with a reference path containing a single element.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A named class for passing around the variadic flag.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.