24 #include "llvm/Support/InterleavedRange.h"
44 bool isVariadic =
false;
49 parser,
false, entryArgs, isVariadic, resultTypes,
55 GraphType grType = builder.
getGraphType(argTypes, resultTypes);
64 assert(resultAttrs.size() == resultTypes.size());
66 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.
name),
67 getResAttrsAttrName(result.
name));
80 GraphType grType = getFunctionType();
82 printer, *
this, grType.getInputs(),
83 false, grType.getResults());
85 {getFunctionTypeAttrName(),
86 getArgAttrsAttrName(),
87 getResAttrsAttrName()});
90 Region &body = this->getBody();
98 LogicalResult spirv::GraphARMOp::verifyType() {
99 if (getFunctionType().getNumResults() < 1)
100 return emitOpError(
"there should be at least one result");
104 LogicalResult spirv::GraphARMOp::verifyBody() {
105 for (
auto [index, graphArgType] :
llvm::enumerate(getArgumentTypes())) {
106 if (!isa<spirv::TensorArmType>(graphArgType)) {
107 return emitOpError(
"type of argument #")
108 << index <<
" must be a TensorArmType, but got " << graphArgType;
112 if (!isa<spirv::TensorArmType>(graphResType)) {
113 return emitOpError(
"type of result #")
114 << index <<
" must be a TensorArmType, but got " << graphResType;
119 Block &entryBlock = front();
121 unsigned numArguments = this->getNumArguments();
123 return emitOpError(
"entry block must have ")
124 << numArguments <<
" arguments to match graph signature";
126 for (
auto [index, grArgType, blockArgType] :
128 if (blockArgType != grArgType) {
129 return emitOpError(
"type of entry block argument #")
130 << index <<
'(' << blockArgType
131 <<
") must match the type of the corresponding argument in "
132 <<
"graph signature(" << grArgType <<
')';
137 GraphType grType = getFunctionType();
138 auto walkResult =
walk([grType](spirv::GraphOutputsARMOp op) ->
WalkResult {
139 if (grType.getNumResults() != op.getNumOperands())
140 return op.emitOpError(
"is returning ")
141 << op.getNumOperands()
142 <<
" value(s) but enclosing spirv.ARM.Graph requires "
143 << grType.getNumResults() <<
" result(s)";
146 op.getValue().getType();
148 if (type != grType.getResult(index))
149 return op.emitError(
"type of return operand ")
150 << index <<
" (" << type <<
") doesn't match graph result type ("
151 << grType.getResult(index) <<
")";
156 return failure(walkResult.wasInterrupted());
160 StringRef name, GraphType type,
164 state.addAttribute(getFunctionTypeAttrName(state.name),
TypeAttr::get(type));
165 state.attributes.append(attrs);
166 state.addAttribute(getEntryPointAttrName(state.name),
172 return getFunctionType().getInputs();
176 return getFunctionType().getResults();
179 Region *spirv::GraphARMOp::getCallableRegion() {
180 return isExternal() ? nullptr : &getBody();
188 auto graph = cast<GraphARMOp>((*this)->getParentOp());
191 const ArrayRef<Type> &results = graph.getFunctionType().getResults();
192 if (getNumOperands() != results.size())
193 return emitOpError(
"has ")
194 << getNumOperands() <<
" operands, but enclosing spirv.ARM.Graph (@"
195 << graph.getName() <<
") returns " << results.size();
198 if (getOperand(index).
getType() != result)
199 return emitError() <<
"type of return operand " << index <<
" ("
200 << getOperand(index).getType()
201 <<
") doesn't match spirv.ARM.Graph result type ("
203 <<
" in graph @" << graph.getName();
211 void spirv::GraphEntryPointARMOp::build(
OpBuilder &builder,
213 spirv::GraphARMOp graph,
230 FlatSymbolRefAttr var;
232 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
234 interfaceVars.push_back(var);
248 if (!interfaceVars.empty()) {
249 printer <<
", " << llvm::interleaved(interfaceVars);
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
This class is a general helper class for creating context-global objects like types,...
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
GraphType getGraphType(TypeRange inputs, TypeRange results)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A symbol reference with a reference path containing a single element.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class implements iteration on the types of a given range of values.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
constexpr char kFnNameAttrName[]
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.