24#include "llvm/Support/InterleavedRange.h"
33ParseResult spirv::GraphARMOp::parse(
OpAsmParser &parser,
44 bool isVariadic =
false;
49 parser,
false, entryArgs, isVariadic, resultTypes,
55 GraphType grType = builder.
getGraphType(argTypes, resultTypes);
56 result.addAttribute(getFunctionTypeAttrName(
result.name),
57 TypeAttr::get(grType));
64 assert(resultAttrs.size() == resultTypes.size());
66 builder,
result, entryArgs, resultAttrs, getArgAttrsAttrName(
result.name),
67 getResAttrsAttrName(
result.name));
80 GraphType grType = getFunctionType();
82 printer, *
this, grType.getInputs(),
83 false, grType.getResults());
85 {getFunctionTypeAttrName(),
86 getArgAttrsAttrName(),
87 getResAttrsAttrName()});
90 Region &body = this->getBody();
98LogicalResult spirv::GraphARMOp::verifyType() {
99 if (getFunctionType().getNumResults() < 1)
100 return emitOpError(
"there should be at least one result");
104LogicalResult spirv::GraphARMOp::verifyBody() {
105 for (
auto [
index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
106 if (!isa<spirv::TensorArmType>(graphArgType)) {
108 <<
index <<
" must be a TensorArmType, but got " << graphArgType;
111 for (
auto [
index, graphResType] : llvm::enumerate(getResultTypes())) {
112 if (!isa<spirv::TensorArmType>(graphResType)) {
114 <<
index <<
" must be a TensorArmType, but got " << graphResType;
119 Block &entryBlock = front();
121 unsigned numArguments = this->getNumArguments();
124 << numArguments <<
" arguments to match graph signature";
126 for (
auto [
index, grArgType, blockArgType] :
128 if (blockArgType != grArgType) {
129 return emitOpError(
"type of entry block argument #")
130 <<
index <<
'(' << blockArgType
131 <<
") must match the type of the corresponding argument in "
132 <<
"graph signature(" << grArgType <<
')';
137 GraphType grType = getFunctionType();
138 auto walkResult =
walk([grType](spirv::GraphOutputsARMOp op) ->
WalkResult {
139 if (grType.getNumResults() != op.getNumOperands())
140 return op.emitOpError(
"is returning ")
141 << op.getNumOperands()
142 <<
" value(s) but enclosing spirv.ARM.Graph requires "
143 << grType.getNumResults() <<
" result(s)";
146 op.getValue().getType();
147 for (
auto [
index, type] : llvm::enumerate(graphOutputOperandTypes)) {
148 if (type != grType.getResult(
index))
149 return op.emitError(
"type of return operand ")
150 <<
index <<
" (" << type <<
") doesn't match graph result type ("
151 << grType.getResult(
index) <<
")";
156 return failure(walkResult.wasInterrupted());
160 StringRef name, GraphType type,
164 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
172 return getFunctionType().getInputs();
176 return getFunctionType().getResults();
179Region *spirv::GraphARMOp::getCallableRegion() {
180 return isExternal() ?
nullptr : &getBody();
187LogicalResult spirv::GraphOutputsARMOp::verify() {
188 auto graph = cast<GraphARMOp>((*this)->getParentOp());
191 const ArrayRef<Type> &results = graph.getFunctionType().getResults();
192 if (getNumOperands() != results.size())
194 << getNumOperands() <<
" operands, but enclosing spirv.ARM.Graph (@"
195 << graph.getName() <<
") returns " << results.size();
197 for (
auto [
index,
result] : llvm::enumerate(results))
200 << getOperand(
index).getType()
201 <<
") doesn't match spirv.ARM.Graph result type ("
203 <<
" in graph @" << graph.getName();
211void spirv::GraphEntryPointARMOp::build(
OpBuilder &builder,
213 spirv::GraphARMOp graph,
215 build(builder, state, SymbolRefAttr::get(graph),
219ParseResult spirv::GraphEntryPointARMOp::parse(
OpAsmParser &parser,
230 FlatSymbolRefAttr var;
232 if (parser.parseAttribute(var, Type(),
"var_symbol", attrs))
234 interfaceVars.push_back(var);
239 result.addAttribute(
"interface",
244void spirv::GraphEntryPointARMOp::print(
OpAsmPrinter &printer) {
248 if (!interfaceVars.empty()) {
249 printer <<
", " << llvm::interleaved(interfaceVars);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseOptionalComma()=0
Parse a ',' token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
This class is a general helper class for creating context-global objects like types,...
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
GraphType getGraphType(TypeRange inputs, TypeRange results)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A symbol reference with a reference path containing a single element.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class implements iteration on the types of a given range of values.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
constexpr char kFnNameAttrName[]
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.