MLIR 22.0.0git
ArmGraphOps.cpp
Go to the documentation of this file.
1//===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVParsingUtils.h"
16
20#include "mlir/IR/Builders.h"
22#include "mlir/IR/Operation.h"
24#include "llvm/Support/InterleavedRange.h"
25
26using namespace mlir;
27using namespace mlir::spirv::AttrNames;
28
29//===----------------------------------------------------------------------===//
30// spirv.GraphARM
31//===----------------------------------------------------------------------===//
32
33ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
35 Builder &builder = parser.getBuilder();
36
37 // Parse the name as a symbol.
38 StringAttr nameAttr;
40 result.attributes))
41 return failure();
42
43 // Parse the function signature.
44 bool isVariadic = false;
46 SmallVector<Type> resultTypes;
49 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
50 resultAttrs))
51 return failure();
52
53 SmallVector<Type> argTypes = llvm::map_to_vector(
54 entryArgs, [](const OpAsmParser::Argument &arg) { return arg.type; });
55 GraphType grType = builder.getGraphType(argTypes, resultTypes);
56 result.addAttribute(getFunctionTypeAttrName(result.name),
57 TypeAttr::get(grType));
58
59 // If additional attributes are present, parse them.
60 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
61 return failure();
62
63 // Add the attributes to the function arguments.
64 assert(resultAttrs.size() == resultTypes.size());
66 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
67 getResAttrsAttrName(result.name));
68
69 // Parse the optional function body.
70 Region *body = result.addRegion();
71 OptionalParseResult parseResult =
72 parser.parseOptionalRegion(*body, entryArgs);
73 return failure(parseResult.has_value() && failed(*parseResult));
74}
75
76void spirv::GraphARMOp::print(OpAsmPrinter &printer) {
77 // Print graph name, signature, and control.
78 printer << " ";
79 printer.printSymbolName(getSymName());
80 GraphType grType = getFunctionType();
82 printer, *this, grType.getInputs(),
83 /*isVariadic=*/false, grType.getResults());
85 {getFunctionTypeAttrName(),
86 getArgAttrsAttrName(),
87 getResAttrsAttrName()});
88
89 // Print the body.
90 Region &body = this->getBody();
91 if (!body.empty()) {
92 printer << ' ';
93 printer.printRegion(body, /*printEntryBlockArgs=*/false,
94 /*printBlockTerminators=*/true);
95 }
96}
97
98LogicalResult spirv::GraphARMOp::verifyType() {
99 if (getFunctionType().getNumResults() < 1)
100 return emitOpError("there should be at least one result");
101 return success();
102}
103
104LogicalResult spirv::GraphARMOp::verifyBody() {
105 for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
106 if (!isa<spirv::TensorArmType>(graphArgType)) {
107 return emitOpError("type of argument #")
108 << index << " must be a TensorArmType, but got " << graphArgType;
109 }
110 }
111 for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
112 if (!isa<spirv::TensorArmType>(graphResType)) {
113 return emitOpError("type of result #")
114 << index << " must be a TensorArmType, but got " << graphResType;
115 }
116 }
117
118 if (!isExternal()) {
119 Block &entryBlock = front();
120
121 unsigned numArguments = this->getNumArguments();
122 if (entryBlock.getNumArguments() != numArguments)
123 return emitOpError("entry block must have ")
124 << numArguments << " arguments to match graph signature";
125
126 for (auto [index, grArgType, blockArgType] :
127 llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
128 if (blockArgType != grArgType) {
129 return emitOpError("type of entry block argument #")
130 << index << '(' << blockArgType
131 << ") must match the type of the corresponding argument in "
132 << "graph signature(" << grArgType << ')';
133 }
134 }
135 }
136
137 GraphType grType = getFunctionType();
138 auto walkResult = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
139 if (grType.getNumResults() != op.getNumOperands())
140 return op.emitOpError("is returning ")
141 << op.getNumOperands()
142 << " value(s) but enclosing spirv.ARM.Graph requires "
143 << grType.getNumResults() << " result(s)";
144
145 ValueTypeRange<OperandRange> graphOutputOperandTypes =
146 op.getValue().getType();
147 for (auto [index, type] : llvm::enumerate(graphOutputOperandTypes)) {
148 if (type != grType.getResult(index))
149 return op.emitError("type of return operand ")
150 << index << " (" << type << ") doesn't match graph result type ("
151 << grType.getResult(index) << ")";
152 }
153 return WalkResult::advance();
154 });
155
156 return failure(walkResult.wasInterrupted());
157}
158
159void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
160 StringRef name, GraphType type,
161 ArrayRef<NamedAttribute> attrs, bool entryPoint) {
163 builder.getStringAttr(name));
164 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
165 state.attributes.append(attrs);
166 state.addAttribute(getEntryPointAttrName(state.name),
167 builder.getBoolAttr(entryPoint));
168 state.addRegion();
169}
170
171ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
172 return getFunctionType().getInputs();
173}
174
175ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
176 return getFunctionType().getResults();
177}
178
179Region *spirv::GraphARMOp::getCallableRegion() {
180 return isExternal() ? nullptr : &getBody();
181}
182
183//===----------------------------------------------------------------------===//
184// spirv.GraphOutputsARM
185//===----------------------------------------------------------------------===//
186
187LogicalResult spirv::GraphOutputsARMOp::verify() {
188 auto graph = cast<GraphARMOp>((*this)->getParentOp());
189
190 // The operand number and types must match the graph signature.
191 const ArrayRef<Type> &results = graph.getFunctionType().getResults();
192 if (getNumOperands() != results.size())
193 return emitOpError("has ")
194 << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
195 << graph.getName() << ") returns " << results.size();
196
197 for (auto [index, result] : llvm::enumerate(results))
198 if (getOperand(index).getType() != result)
199 return emitError() << "type of return operand " << index << " ("
200 << getOperand(index).getType()
201 << ") doesn't match spirv.ARM.Graph result type ("
202 << result << ")"
203 << " in graph @" << graph.getName();
204 return success();
205}
206
207//===----------------------------------------------------------------------===//
208// spirv.GraphEntryPointARM
209//===----------------------------------------------------------------------===//
210
211void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
212 OperationState &state,
213 spirv::GraphARMOp graph,
214 ArrayRef<Attribute> interfaceVars) {
215 build(builder, state, SymbolRefAttr::get(graph),
216 builder.getArrayAttr(interfaceVars));
217}
218
219ParseResult spirv::GraphEntryPointARMOp::parse(OpAsmParser &parser,
222 if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
223 return failure();
224
225 SmallVector<Attribute, 4> interfaceVars;
226 if (!parser.parseOptionalComma()) {
227 // Parse the interface variables.
228 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
229 // The name of the interface variable attribute is not important.
230 FlatSymbolRefAttr var;
231 NamedAttrList attrs;
232 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
233 return failure();
234 interfaceVars.push_back(var);
235 return success();
236 }))
237 return failure();
238 }
239 result.addAttribute("interface",
240 parser.getBuilder().getArrayAttr(interfaceVars));
241 return success();
242}
243
244void spirv::GraphEntryPointARMOp::print(OpAsmPrinter &printer) {
245 printer << " ";
246 printer.printSymbolName(getFn());
247 ArrayRef<Attribute> interfaceVars = getInterface().getValue();
248 if (!interfaceVars.empty()) {
249 printer << ", " << llvm::interleaved(interfaceVars);
250 }
251}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
unsigned getNumArguments()
Definition Block.h:128
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
GraphType getGraphType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
A symbol reference with a reference path containing a single element.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:207
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class implements iteration on the types of a given range of values.
Definition TypeRange.h:135
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult advance()
Definition WalkResult.h:47
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
constexpr char kFnNameAttrName[]
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.