MLIR  19.0.0git
MLProgramOps.cpp
Go to the documentation of this file.
1 //===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
12 
13 using namespace mlir;
14 using namespace mlir::ml_program;
15 
16 //===----------------------------------------------------------------------===//
17 // Custom asm helpers
18 //===----------------------------------------------------------------------===//
19 
20 /// Parse and print an ordering clause for a variadic list of consuming tokens
21 /// and a producing token.
22 ///
23 /// Syntax:
24 /// ordering(%0, %1 -> !ml_program.token)
25 /// ordering(() -> !ml_program.token)
26 ///
27 /// If neither consuming tokens nor a producing token are present on the op,
28 /// then the clause prints nothing.
30  OpAsmParser &parser,
32  Type &produceTokenType) {
33  if (failed(parser.parseOptionalKeyword("ordering")) ||
34  failed(parser.parseLParen()))
35  return success();
36 
37  // Parse consuming token list. If there are no consuming tokens, the
38  // '()' null list represents this.
39  if (succeeded(parser.parseOptionalLParen())) {
40  if (failed(parser.parseRParen()))
41  return failure();
42  } else {
43  if (failed(parser.parseOperandList(consumeTokens,
44  /*requiredOperandCount=*/-1)))
45  return failure();
46  }
47 
48  // Parse producer token.
49  if (failed(parser.parseArrow()))
50  return failure();
51  if (failed(parser.parseType(produceTokenType)))
52  return failure();
53 
54  if (failed(parser.parseRParen()))
55  return failure();
56 
57  return success();
58 }
59 
61  OperandRange consumeTokens,
62  Type produceTokenType) {
63  if (consumeTokens.empty() && !produceTokenType)
64  return;
65 
66  p << " ordering(";
67  if (consumeTokens.empty())
68  p << "()";
69  else
70  p.printOperands(consumeTokens);
71  if (produceTokenType) {
72  p << " -> ";
73  p.printType(produceTokenType);
74  }
75  p << ")";
76 }
77 
78 /// some.op custom<TypedInitialValue>($type, $attr)
79 ///
80 /// Uninitialized:
81 /// some.op : tensor<3xi32>
82 /// Initialized to narrower type than op:
83 /// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
85  TypeAttr &typeAttr, Attribute &attr) {
86  if (succeeded(parser.parseOptionalLParen())) {
87  if (failed(parser.parseAttribute(attr)))
88  return failure();
89  if (failed(parser.parseRParen()))
90  return failure();
91  }
92 
93  Type type;
94  if (failed(parser.parseColonType(type)))
95  return failure();
96  typeAttr = TypeAttr::get(type);
97  return success();
98 }
99 
101  TypeAttr type, Attribute attr) {
102  if (attr) {
103  p << "(";
104  p.printAttribute(attr);
105  p << ")";
106  }
107 
108  p << " : ";
109  p.printAttribute(type);
110 }
111 
112 /// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113 /// ->
114 /// some.op public @foo
115 /// some.op private @foo
117  StringAttr &symVisibilityAttr) {
118  StringRef symVisibility;
119  (void)parser.parseOptionalKeyword(&symVisibility,
120  {"public", "private", "nested"});
121  if (symVisibility.empty())
122  return parser.emitError(parser.getCurrentLocation())
123  << "expected 'public', 'private', or 'nested'";
124  if (!symVisibility.empty())
125  symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126  return success();
127 }
128 
130  StringAttr symVisibilityAttr) {
131  if (!symVisibilityAttr)
132  p << "public";
133  else
134  p << symVisibilityAttr.getValue();
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // TableGen'd op method definitions
139 //===----------------------------------------------------------------------===//
140 
141 #define GET_OP_CLASSES
142 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143 
144 //===----------------------------------------------------------------------===//
145 // FuncOp
146 //===----------------------------------------------------------------------===//
147 
149  auto buildFuncType =
150  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
152  std::string &) { return builder.getFunctionType(argTypes, results); };
153 
155  parser, result, /*allowVariadic=*/false,
156  getFunctionTypeAttrName(result.name), buildFuncType,
157  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
158 }
159 
160 void FuncOp::print(OpAsmPrinter &p) {
162  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
163  getArgAttrsAttrName(), getResAttrsAttrName());
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // GlobalOp
168 //===----------------------------------------------------------------------===//
169 
171  if (!getIsMutable() && !getValue())
172  return emitOpError() << "immutable global must have an initial value";
173  return success();
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // GlobalLoadOp
178 //===----------------------------------------------------------------------===//
179 
180 GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
181  for (auto *parent = getOperation()->getParentOp(); parent;
182  parent = parent->getParentOp()) {
183  if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
184  parent, getGlobalAttr())) {
185  return nearest;
186  }
187  }
188  return {};
189 }
190 
192 GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
193  GlobalOp referrent = getGlobalOp(symbolTable);
194  if (!referrent)
195  return emitOpError() << "undefined global: " << getGlobal();
196 
197  if (referrent.getType() != getResult().getType()) {
198  return emitOpError() << "cannot load from global typed "
199  << referrent.getType() << " as "
200  << getResult().getType();
201  }
202 
203  return success();
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // GlobalLoadConstOp
208 //===----------------------------------------------------------------------===//
209 
210 GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
211  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
212  getOperation()->getParentOp(), getGlobalAttr());
213 }
214 
216 GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
217  GlobalOp referrent = getGlobalOp(symbolTable);
218  if (!referrent)
219  return emitOpError() << "undefined global: " << getGlobal();
220 
221  if (referrent.getIsMutable())
222  return emitOpError() << "cannot load as const from mutable global "
223  << getGlobal();
224 
225  if (referrent.getType() != getResult().getType())
226  return emitOpError() << "cannot load from global typed "
227  << referrent.getType() << " as "
228  << getResult().getType();
229 
230  return success();
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // GlobalLoadGraphOp
235 //===----------------------------------------------------------------------===//
236 
237 GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
238  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
239  getOperation()->getParentOp(), getGlobalAttr());
240 }
241 
243 GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
244  GlobalOp referrent = getGlobalOp(symbolTable);
245  if (!referrent)
246  return emitOpError() << "undefined global: " << getGlobal();
247 
248  if (referrent.getType() != getResult().getType()) {
249  return emitOpError() << "cannot load from global typed "
250  << referrent.getType() << " as "
251  << getResult().getType();
252  }
253 
254  return success();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // GlobalStoreOp
259 //===----------------------------------------------------------------------===//
260 
261 GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
262  for (auto *parent = getOperation()->getParentOp(); parent;) {
263  if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
264  parent, getGlobalAttr())) {
265  return nearest;
266  }
267  parent = parent->getParentOp();
268  }
269  return {};
270 }
271 
273 GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
274  GlobalOp referrent = getGlobalOp(symbolTable);
275  if (!referrent)
276  return emitOpError() << "undefined global: " << getGlobal();
277 
278  if (!referrent.getIsMutable()) {
279  return emitOpError() << "cannot store to an immutable global "
280  << getGlobal();
281  }
282 
283  if (referrent.getType() != getValue().getType()) {
284  return emitOpError() << "cannot store to a global typed "
285  << referrent.getType() << " from "
286  << getValue().getType();
287  }
288 
289  return success();
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // GlobalStoreGraphOp
294 //===----------------------------------------------------------------------===//
295 
296 GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
297  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
298  getOperation()->getParentOp(), getGlobalAttr());
299 }
300 
302 GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
303  GlobalOp referrent = getGlobalOp(symbolTable);
304  if (!referrent)
305  return emitOpError() << "undefined global: " << getGlobal();
306 
307  if (!referrent.getIsMutable()) {
308  return emitOpError() << "cannot store to an immutable global "
309  << getGlobal();
310  }
311 
312  if (referrent.getType() != getValue().getType()) {
313  return emitOpError() << "cannot store to a global typed "
314  << referrent.getType() << " from "
315  << getValue().getType();
316  }
317 
318  return success();
319 }
320 
321 //===----------------------------------------------------------------------===//
322 // SubgraphOp
323 //===----------------------------------------------------------------------===//
324 
326  auto buildFuncType =
327  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
329  std::string &) { return builder.getFunctionType(argTypes, results); };
330 
332  parser, result, /*allowVariadic=*/false,
333  getFunctionTypeAttrName(result.name), buildFuncType,
334  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
335 }
336 
339  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
340  getArgAttrsAttrName(), getResAttrsAttrName());
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // OutputOp
345 //===----------------------------------------------------------------------===//
346 
348  auto function = cast<SubgraphOp>((*this)->getParentOp());
349 
350  // The operand number and types must match the function signature.
351  const auto &results = function.getFunctionType().getResults();
352  if (getNumOperands() != results.size())
353  return emitOpError("has ")
354  << getNumOperands() << " operands, but enclosing function (@"
355  << function.getName() << ") outputs " << results.size();
356 
357  for (unsigned i = 0, e = results.size(); i != e; ++i)
358  if (getOperand(i).getType() != results[i])
359  return emitError() << "type of output operand " << i << " ("
360  << getOperand(i).getType()
361  << ") doesn't match function result type ("
362  << results[i] << ")"
363  << " in function @" << function.getName();
364 
365  return success();
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // ReturnOp
370 //===----------------------------------------------------------------------===//
371 
373  auto function = cast<FuncOp>((*this)->getParentOp());
374 
375  // The operand number and types must match the function signature.
376  const auto &results = function.getFunctionType().getResults();
377  if (getNumOperands() != results.size())
378  return emitOpError("has ")
379  << getNumOperands() << " operands, but enclosing function (@"
380  << function.getName() << ") returns " << results.size();
381 
382  for (unsigned i = 0, e = results.size(); i != e; ++i)
383  if (getOperand(i).getType() != results[i])
384  return emitError() << "type of return operand " << i << " ("
385  << getOperand(i).getType()
386  << ") doesn't match function result type ("
387  << results[i] << ")"
388  << " in function @" << function.getName();
389 
390  return success();
391 }
static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, TypeAttr type, Attribute attr)
static void printTokenOrdering(OpAsmPrinter &p, Operation *op, OperandRange consumeTokens, Type produceTokenType)
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, StringAttr symVisibilityAttr)
static ParseResult parseTypedInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &attr)
some.op custom<TypedInitialValue>($type, $attr)
static ParseResult parseTokenOrdering(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &consumeTokens, Type &produceTokenType)
Parse and print an ordering clause for a variadic list of consuming tokens and a producing token.
static ParseResult parseSymbolVisibility(OpAsmParser &parser, StringAttr &symVisibilityAttr)
some.op custom<SymbolVisibility>($sym_visibility) $sym_name -> some.op public @foo some....
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents success/failure for parsing-like operations that find it important to chain tog...
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.