MLIR 22.0.0git
MLProgramOps.cpp
Go to the documentation of this file.
1//===- MLProgramOps.cpp - MLProgram dialect ops implementation ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/IR/Builders.h"
12
13using namespace mlir;
14using namespace mlir::ml_program;
15
16//===----------------------------------------------------------------------===//
17// Custom asm helpers
18//===----------------------------------------------------------------------===//
19
20/// Parse and print an ordering clause for a variadic of consuming tokens
21/// and an producing token.
22///
23/// Syntax:
24/// ordering(%0, %1 -> !ml_program.token)
25/// ordering(() -> !ml_program.token)
26///
27/// If both the consuming and producing token are not present on the op, then
28/// the clause prints nothing.
29static ParseResult parseTokenOrdering(
30 OpAsmParser &parser,
32 Type &produceTokenType) {
33 if (failed(parser.parseOptionalKeyword("ordering")) ||
34 failed(parser.parseLParen()))
35 return success();
36
37 // Parse consuming token list. If there are no consuming tokens, the
38 // '()' null list represents this.
39 if (succeeded(parser.parseOptionalLParen())) {
40 if (failed(parser.parseRParen()))
41 return failure();
42 } else {
43 if (failed(parser.parseOperandList(consumeTokens,
44 /*requiredOperandCount=*/-1)))
45 return failure();
46 }
47
48 // Parse producer token.
49 if (failed(parser.parseArrow()))
50 return failure();
51 if (failed(parser.parseType(produceTokenType)))
52 return failure();
53
54 if (failed(parser.parseRParen()))
55 return failure();
56
57 return success();
58}
59
61 OperandRange consumeTokens,
62 Type produceTokenType) {
63 if (consumeTokens.empty() && !produceTokenType)
64 return;
65
66 p << " ordering(";
67 if (consumeTokens.empty())
68 p << "()";
69 else
70 p.printOperands(consumeTokens);
71 if (produceTokenType) {
72 p << " -> ";
73 p.printType(produceTokenType);
74 }
75 p << ")";
76}
77
78/// some.op custom<TypeOrAttr>($type, $attr)
79///
80/// Uninitialized:
81/// some.op : tensor<3xi32>
82/// Initialized to narrower type than op:
83/// some.op (dense<0> : tensor<3xi32>) : tensor<?xi32>
84static ParseResult parseTypedInitialValue(OpAsmParser &parser,
85 TypeAttr &typeAttr, Attribute &attr) {
86 if (succeeded(parser.parseOptionalLParen())) {
87 if (failed(parser.parseAttribute(attr)))
88 return failure();
89 if (failed(parser.parseRParen()))
90 return failure();
91 }
92
93 Type type;
94 if (failed(parser.parseColonType(type)))
95 return failure();
96 typeAttr = TypeAttr::get(type);
97 return success();
98}
99
101 TypeAttr type, Attribute attr) {
102 if (attr) {
103 p << "(";
104 p.printAttribute(attr);
105 p << ")";
106 }
107
108 p << " : ";
109 p.printAttribute(type);
110}
111
112/// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
113/// ->
114/// some.op public @foo
115/// some.op private @foo
116static ParseResult parseSymbolVisibility(OpAsmParser &parser,
117 StringAttr &symVisibilityAttr) {
118 StringRef symVisibility;
119 (void)parser.parseOptionalKeyword(&symVisibility,
120 {"public", "private", "nested"});
121 if (symVisibility.empty())
122 return parser.emitError(parser.getCurrentLocation())
123 << "expected 'public', 'private', or 'nested'";
124 if (!symVisibility.empty())
125 symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
126 return success();
127}
128
130 StringAttr symVisibilityAttr) {
131 if (!symVisibilityAttr)
132 p << "public";
133 else
134 p << symVisibilityAttr.getValue();
135}
136
137//===----------------------------------------------------------------------===//
138// TableGen'd op method definitions
139//===----------------------------------------------------------------------===//
140
141#define GET_OP_CLASSES
142#include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
143
144//===----------------------------------------------------------------------===//
145// FuncOp
146//===----------------------------------------------------------------------===//
147
148ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
149 auto buildFuncType =
150 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
152 std::string &) { return builder.getFunctionType(argTypes, results); };
153
155 parser, result, /*allowVariadic=*/false,
156 getFunctionTypeAttrName(result.name), buildFuncType,
157 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
158}
159
160void FuncOp::print(OpAsmPrinter &p) {
162 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
163 getArgAttrsAttrName(), getResAttrsAttrName());
164}
165
166//===----------------------------------------------------------------------===//
167// GlobalOp
168//===----------------------------------------------------------------------===//
169
170LogicalResult GlobalOp::verify() {
171 if (!getIsMutable() && !getValue())
172 return emitOpError() << "immutable global must have an initial value";
173 return success();
174}
175
176//===----------------------------------------------------------------------===//
177// GlobalLoadOp
178//===----------------------------------------------------------------------===//
179
180GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
181 for (auto *parent = getOperation()->getParentOp(); parent;
182 parent = parent->getParentOp()) {
183 if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
184 parent, getGlobalAttr())) {
185 return nearest;
186 }
187 }
188 return {};
189}
190
191LogicalResult
192GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
193 GlobalOp referrent = getGlobalOp(symbolTable);
194 if (!referrent)
195 return emitOpError() << "undefined global: " << getGlobal();
196
197 if (referrent.getType() != getResult().getType()) {
198 return emitOpError() << "cannot load from global typed "
199 << referrent.getType() << " as "
200 << getResult().getType();
201 }
202
203 return success();
204}
205
206//===----------------------------------------------------------------------===//
207// GlobalLoadConstOp
208//===----------------------------------------------------------------------===//
209
210GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) {
211 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
212 getOperation()->getParentOp(), getGlobalAttr());
213}
214
215LogicalResult
216GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
217 GlobalOp referrent = getGlobalOp(symbolTable);
218 if (!referrent)
219 return emitOpError() << "undefined global: " << getGlobal();
220
221 if (referrent.getIsMutable())
222 return emitOpError() << "cannot load as const from mutable global "
223 << getGlobal();
224
225 if (referrent.getType() != getResult().getType())
226 return emitOpError() << "cannot load from global typed "
227 << referrent.getType() << " as "
228 << getResult().getType();
229
230 return success();
231}
232
233//===----------------------------------------------------------------------===//
234// GlobalLoadGraphOp
235//===----------------------------------------------------------------------===//
236
237GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
238 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
239 getOperation()->getParentOp(), getGlobalAttr());
240}
241
242LogicalResult
243GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
244 GlobalOp referrent = getGlobalOp(symbolTable);
245 if (!referrent)
246 return emitOpError() << "undefined global: " << getGlobal();
247
248 if (referrent.getType() != getResult().getType()) {
249 return emitOpError() << "cannot load from global typed "
250 << referrent.getType() << " as "
251 << getResult().getType();
252 }
253
254 return success();
255}
256
257//===----------------------------------------------------------------------===//
258// GlobalStoreOp
259//===----------------------------------------------------------------------===//
260
261GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
262 for (auto *parent = getOperation()->getParentOp(); parent;) {
263 if (auto nearest = symbolTable.lookupNearestSymbolFrom<GlobalOp>(
264 parent, getGlobalAttr())) {
265 return nearest;
266 }
267 parent = parent->getParentOp();
268 }
269 return {};
270}
271
272LogicalResult
273GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
274 GlobalOp referrent = getGlobalOp(symbolTable);
275 if (!referrent)
276 return emitOpError() << "undefined global: " << getGlobal();
277
278 if (!referrent.getIsMutable()) {
279 return emitOpError() << "cannot store to an immutable global "
280 << getGlobal();
281 }
282
283 if (referrent.getType() != getValue().getType()) {
284 return emitOpError() << "cannot store to a global typed "
285 << referrent.getType() << " from "
286 << getValue().getType();
287 }
288
289 return success();
290}
291
292//===----------------------------------------------------------------------===//
293// GlobalStoreGraphOp
294//===----------------------------------------------------------------------===//
295
296GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
297 return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
298 getOperation()->getParentOp(), getGlobalAttr());
299}
300
301LogicalResult
302GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
303 GlobalOp referrent = getGlobalOp(symbolTable);
304 if (!referrent)
305 return emitOpError() << "undefined global: " << getGlobal();
306
307 if (!referrent.getIsMutable()) {
308 return emitOpError() << "cannot store to an immutable global "
309 << getGlobal();
310 }
311
312 if (referrent.getType() != getValue().getType()) {
313 return emitOpError() << "cannot store to a global typed "
314 << referrent.getType() << " from "
315 << getValue().getType();
316 }
317
318 return success();
319}
320
321//===----------------------------------------------------------------------===//
322// SubgraphOp
323//===----------------------------------------------------------------------===//
324
325ParseResult SubgraphOp::parse(OpAsmParser &parser, OperationState &result) {
326 auto buildFuncType =
327 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
329 std::string &) { return builder.getFunctionType(argTypes, results); };
330
332 parser, result, /*allowVariadic=*/false,
333 getFunctionTypeAttrName(result.name), buildFuncType,
334 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
335}
336
337void SubgraphOp::print(OpAsmPrinter &p) {
339 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
340 getArgAttrsAttrName(), getResAttrsAttrName());
341}
342
343//===----------------------------------------------------------------------===//
344// OutputOp
345//===----------------------------------------------------------------------===//
346
347LogicalResult OutputOp::verify() {
348 auto function = cast<SubgraphOp>((*this)->getParentOp());
349
350 // The operand number and types must match the function signature.
351 const auto &results = function.getFunctionType().getResults();
352 if (getNumOperands() != results.size())
353 return emitOpError("has ")
354 << getNumOperands() << " operands, but enclosing function (@"
355 << function.getName() << ") outputs " << results.size();
356
357 for (unsigned i = 0, e = results.size(); i != e; ++i)
358 if (getOperand(i).getType() != results[i])
359 return emitError() << "type of output operand " << i << " ("
360 << getOperand(i).getType()
361 << ") doesn't match function result type ("
362 << results[i] << ")"
363 << " in function @" << function.getName();
364
365 return success();
366}
367
368//===----------------------------------------------------------------------===//
369// ReturnOp
370//===----------------------------------------------------------------------===//
371
372LogicalResult ReturnOp::verify() {
373 auto function = cast<FuncOp>((*this)->getParentOp());
374
375 // The operand number and types must match the function signature.
376 const auto &results = function.getFunctionType().getResults();
377 if (getNumOperands() != results.size())
378 return emitOpError("has ")
379 << getNumOperands() << " operands, but enclosing function (@"
380 << function.getName() << ") returns " << results.size();
381
382 for (unsigned i = 0, e = results.size(); i != e; ++i)
383 if (getOperand(i).getType() != results[i])
384 return emitError() << "type of return operand " << i << " ("
385 << getOperand(i).getType()
386 << ") doesn't match function result type ("
387 << results[i] << ")"
388 << " in function @" << function.getName();
389
390 return success();
391}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, TypeAttr type, Attribute attr)
static void printTokenOrdering(OpAsmPrinter &p, Operation *op, OperandRange consumeTokens, Type produceTokenType)
static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, StringAttr symVisibilityAttr)
static ParseResult parseTypedInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &attr)
some.op custom<TypeOrAttr>($type, $attr)
static ParseResult parseTokenOrdering(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &consumeTokens, Type &produceTokenType)
Parse and print an ordering clause for a variadic of consuming tokens and an producing token.
static ParseResult parseSymbolVisibility(OpAsmParser &parser, StringAttr &symVisibilityAttr)
some.op custom<SymbolVisibility>($sym_visibility) $sym_name -> some.op public @foo some....
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This represents an operation in an abstracted form, suitable for use with the builder APIs.