MLIR  22.0.0git
ArmGraphOps.cpp
Go to the documentation of this file.
1 //===- ArmGraphOps.cpp - MLIR SPIR-V SPV_ARM_graph operations -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPV_ARM_graph operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "SPIRVParsingUtils.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/Support/InterleavedRange.h"
25 
26 using namespace mlir;
27 using namespace mlir::spirv::AttrNames;
28 
29 //===----------------------------------------------------------------------===//
30 // spirv.GraphARM
31 //===----------------------------------------------------------------------===//
32 
33 ParseResult spirv::GraphARMOp::parse(OpAsmParser &parser,
34  OperationState &result) {
35  Builder &builder = parser.getBuilder();
36 
37  // Parse the name as a symbol.
38  StringAttr nameAttr;
39  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
40  result.attributes))
41  return failure();
42 
43  // Parse the function signature.
44  bool isVariadic = false;
46  SmallVector<Type> resultTypes;
47  SmallVector<DictionaryAttr> resultAttrs;
49  parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
50  resultAttrs))
51  return failure();
52 
53  SmallVector<Type> argTypes = llvm::map_to_vector(
54  entryArgs, [](const OpAsmParser::Argument &arg) { return arg.type; });
55  GraphType grType = builder.getGraphType(argTypes, resultTypes);
56  result.addAttribute(getFunctionTypeAttrName(result.name),
57  TypeAttr::get(grType));
58 
59  // If additional attributes are present, parse them.
61  return failure();
62 
63  // Add the attributes to the function arguments.
64  assert(resultAttrs.size() == resultTypes.size());
66  builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
67  getResAttrsAttrName(result.name));
68 
69  // Parse the optional function body.
70  Region *body = result.addRegion();
71  OptionalParseResult parseResult =
72  parser.parseOptionalRegion(*body, entryArgs);
73  return failure(parseResult.has_value() && failed(*parseResult));
74 }
75 
77  // Print graph name, signature, and control.
78  printer << " ";
79  printer.printSymbolName(getSymName());
80  GraphType grType = getFunctionType();
82  printer, *this, grType.getInputs(),
83  /*isVariadic=*/false, grType.getResults());
85  {getFunctionTypeAttrName(),
86  getArgAttrsAttrName(),
87  getResAttrsAttrName()});
88 
89  // Print the body.
90  Region &body = this->getBody();
91  if (!body.empty()) {
92  printer << ' ';
93  printer.printRegion(body, /*printEntryBlockArgs=*/false,
94  /*printBlockTerminators=*/true);
95  }
96 }
97 
98 LogicalResult spirv::GraphARMOp::verifyType() {
99  if (getFunctionType().getNumResults() < 1)
100  return emitOpError("there should be at least one result");
101  return success();
102 }
103 
104 LogicalResult spirv::GraphARMOp::verifyBody() {
105  for (auto [index, graphArgType] : llvm::enumerate(getArgumentTypes())) {
106  if (!isa<spirv::TensorArmType>(graphArgType)) {
107  return emitOpError("type of argument #")
108  << index << " must be a TensorArmType, but got " << graphArgType;
109  }
110  }
111  for (auto [index, graphResType] : llvm::enumerate(getResultTypes())) {
112  if (!isa<spirv::TensorArmType>(graphResType)) {
113  return emitOpError("type of result #")
114  << index << " must be a TensorArmType, but got " << graphResType;
115  }
116  }
117 
118  if (!isExternal()) {
119  Block &entryBlock = front();
120 
121  unsigned numArguments = this->getNumArguments();
122  if (entryBlock.getNumArguments() != numArguments)
123  return emitOpError("entry block must have ")
124  << numArguments << " arguments to match graph signature";
125 
126  for (auto [index, grArgType, blockArgType] :
127  llvm::enumerate(getArgumentTypes(), entryBlock.getArgumentTypes())) {
128  if (blockArgType != grArgType) {
129  return emitOpError("type of entry block argument #")
130  << index << '(' << blockArgType
131  << ") must match the type of the corresponding argument in "
132  << "graph signature(" << grArgType << ')';
133  }
134  }
135  }
136 
137  GraphType grType = getFunctionType();
138  auto walkResult = walk([grType](spirv::GraphOutputsARMOp op) -> WalkResult {
139  if (grType.getNumResults() != op.getNumOperands())
140  return op.emitOpError("is returning ")
141  << op.getNumOperands()
142  << " value(s) but enclosing spirv.ARM.Graph requires "
143  << grType.getNumResults() << " result(s)";
144 
145  ValueTypeRange<OperandRange> graphOutputOperandTypes =
146  op.getValue().getType();
147  for (auto [index, type] : llvm::enumerate(graphOutputOperandTypes)) {
148  if (type != grType.getResult(index))
149  return op.emitError("type of return operand ")
150  << index << " (" << type << ") doesn't match graph result type ("
151  << grType.getResult(index) << ")";
152  }
153  return WalkResult::advance();
154  });
155 
156  return failure(walkResult.wasInterrupted());
157 }
158 
159 void spirv::GraphARMOp::build(OpBuilder &builder, OperationState &state,
160  StringRef name, GraphType type,
161  ArrayRef<NamedAttribute> attrs, bool entryPoint) {
162  state.addAttribute(SymbolTable::getSymbolAttrName(),
163  builder.getStringAttr(name));
164  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
165  state.attributes.append(attrs);
166  state.addAttribute(getEntryPointAttrName(state.name),
167  builder.getBoolAttr(entryPoint));
168  state.addRegion();
169 }
170 
171 ArrayRef<Type> spirv::GraphARMOp::getArgumentTypes() {
172  return getFunctionType().getInputs();
173 }
174 
175 ArrayRef<Type> spirv::GraphARMOp::getResultTypes() {
176  return getFunctionType().getResults();
177 }
178 
179 Region *spirv::GraphARMOp::getCallableRegion() {
180  return isExternal() ? nullptr : &getBody();
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // spirv.GraphOutputsARM
185 //===----------------------------------------------------------------------===//
186 
187 LogicalResult spirv::GraphOutputsARMOp::verify() {
188  auto graph = cast<GraphARMOp>((*this)->getParentOp());
189 
190  // The operand number and types must match the graph signature.
191  const ArrayRef<Type> &results = graph.getFunctionType().getResults();
192  if (getNumOperands() != results.size())
193  return emitOpError("has ")
194  << getNumOperands() << " operands, but enclosing spirv.ARM.Graph (@"
195  << graph.getName() << ") returns " << results.size();
196 
197  for (auto [index, result] : llvm::enumerate(results))
198  if (getOperand(index).getType() != result)
199  return emitError() << "type of return operand " << index << " ("
200  << getOperand(index).getType()
201  << ") doesn't match spirv.ARM.Graph result type ("
202  << result << ")"
203  << " in graph @" << graph.getName();
204  return success();
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // spirv.GraphEntryPointARM
209 //===----------------------------------------------------------------------===//
210 
211 void spirv::GraphEntryPointARMOp::build(OpBuilder &builder,
212  OperationState &state,
213  spirv::GraphARMOp graph,
214  ArrayRef<Attribute> interfaceVars) {
215  build(builder, state, SymbolRefAttr::get(graph),
216  builder.getArrayAttr(interfaceVars));
217 }
218 
220  OperationState &result) {
222  if (parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes))
223  return failure();
224 
225  SmallVector<Attribute, 4> interfaceVars;
226  if (!parser.parseOptionalComma()) {
227  // Parse the interface variables.
228  if (parser.parseCommaSeparatedList([&]() -> ParseResult {
229  // The name of the interface variable attribute is not important.
230  FlatSymbolRefAttr var;
231  NamedAttrList attrs;
232  if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
233  return failure();
234  interfaceVars.push_back(var);
235  return success();
236  }))
237  return failure();
238  }
239  result.addAttribute("interface",
240  parser.getBuilder().getArrayAttr(interfaceVars));
241  return success();
242 }
243 
245  printer << " ";
246  printer.printSymbolName(getFn());
247  ArrayRef<Attribute> interfaceVars = getInterface().getValue();
248  if (!interfaceVars.empty()) {
249  printer << ", " << llvm::interleaved(interfaceVars);
250  }
251 }
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:99
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
GraphType getGraphType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:79
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
A symbol reference with a reference path containing a single element.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:207
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class implements iteration on the types of a given range of values.
Definition: TypeRange.h:135
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes, prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
constexpr char kFnNameAttrName[]
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.