MLIR 22.0.0git
WasmSSAOps.cpp
Go to the documentation of this file.
1//===- WasmSSAOps.cpp - WasmSSA dialect operations ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===---------------------------------------------------------------------===//
8
11
12#include "mlir/IR/Attributes.h"
13#include "mlir/IR/Builders.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/Region.h"
18#include "mlir/IR/SymbolTable.h"
20#include "llvm/Support/Casting.h"
21
22//===----------------------------------------------------------------------===//
23// TableGen'd op method definitions
24//===----------------------------------------------------------------------===//
25
26using namespace mlir;
27namespace {
28ParseResult parseElseRegion(OpAsmParser &opParser, Region &elseRegion) {
29 std::string keyword;
30 std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
31 if (keyword == "else")
32 return opParser.parseRegion(elseRegion);
33 return ParseResult::success();
34}
35
36void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
37 Region &elseRegion) {
38 if (elseRegion.empty())
39 return;
40 opPrinter.printKeywordOrString("else ");
41 opPrinter.printRegion(elseRegion);
42}
43} // namespace
44
45#define GET_OP_CLASSES
46#include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc"
47
49#include "mlir/IR/Types.h"
50#include "llvm/Support/LogicalResult.h"
51
52using namespace wasmssa;
53
54namespace {
55inline LogicalResult
56inferTeeGetResType(ValueRange operands,
57 SmallVectorImpl<Type> &inferredReturnTypes) {
58 if (operands.empty())
59 return failure();
60 auto opType = dyn_cast<LocalRefType>(operands.front().getType());
61 if (!opType)
62 return failure();
63 inferredReturnTypes.push_back(opType.getElementType());
64 return success();
65}
66
67ParseResult parseImportOp(OpAsmParser &parser, OperationState &result) {
68 std::string importName;
69 auto *ctx = parser.getContext();
70 ParseResult res = parser.parseString(&importName);
71 result.addAttribute("importName", StringAttr::get(ctx, importName));
72
73 std::string fromStr;
74 res = parser.parseKeywordOrString(&fromStr);
75 if (failed(res) || fromStr != "from")
76 return failure();
77
78 std::string moduleName;
79 res = parser.parseString(&moduleName);
80 if (failed(res))
81 return failure();
82 result.addAttribute("moduleName", StringAttr::get(ctx, moduleName));
83
84 std::string asStr;
85 res = parser.parseKeywordOrString(&asStr);
86 if (failed(res) || asStr != "as")
87 return failure();
88
89 StringAttr symbolName;
90 res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
91 result.attributes);
92 return res;
93}
94} // namespace
95
96//===----------------------------------------------------------------------===//
97// BlockOp
98//===----------------------------------------------------------------------===//
99
100Block *BlockOp::getLabelTarget() { return getTarget(); }
101
102//===----------------------------------------------------------------------===//
103// BlockReturnOp
104//===----------------------------------------------------------------------===//
105
106std::size_t BlockReturnOp::getExitLevel() { return 0; }
107
108Block *BlockReturnOp::getTarget() {
109 return cast<LabelBranchingOpInterface>(getOperation())
110 .getTargetOp()
111 .getOperation()
112 ->getSuccessor(0);
113}
114
115//===----------------------------------------------------------------------===//
116// ExtendLowBitsSOp
117//===----------------------------------------------------------------------===//
118
119LogicalResult ExtendLowBitsSOp::verify() {
120 auto bitsToTake = getBitsToTake().getValue().getLimitedValue();
121 if (bitsToTake != 32 && bitsToTake != 16 && bitsToTake != 8)
122 return emitError("extend op can only take 8, 16 or 32 bits. Got ")
123 << bitsToTake;
124
125 if (bitsToTake >= getInput().getType().getIntOrFloatBitWidth())
126 return emitError("trying to extend the ")
127 << bitsToTake << " low bits from a " << getInput().getType()
128 << " value is illegal";
129 return success();
130}
131
132//===----------------------------------------------------------------------===//
133// FuncOp
134//===----------------------------------------------------------------------===//
135
136Block *FuncOp::addEntryBlock() {
137 if (!getBody().empty()) {
138 emitError("adding entry block to a FuncOp which already has one");
139 return &getBody().front();
140 }
141 Block &block = getBody().emplaceBlock();
142 for (auto argType : getFunctionType().getInputs())
143 block.addArgument(LocalRefType::get(argType), getLoc());
144 return &block;
145}
146
147void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
148 StringRef symbol, FunctionType funcType) {
149 FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {});
150}
151
152ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
153 auto *ctx = parser.getContext();
154 std::string visibilityString;
155 auto loc = parser.getNameLoc();
156 ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
157 bool exported{false};
158 if (res.succeeded()) {
159 if (visibilityString != "exported")
160 return parser.emitError(
161 loc, "expecting either `exported` or symbol name. got ")
162 << visibilityString;
163 exported = true;
164 }
165
166 auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes,
167 ArrayRef<Type> results,
169 std::string &) {
170 SmallVector<Type> argTypesWithoutLocal{};
171 argTypesWithoutLocal.reserve(argTypes.size());
172 llvm::for_each(argTypes, [&parser, &argTypesWithoutLocal](Type argType) {
173 auto refType = dyn_cast<LocalRefType>(argType);
174 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
175 if (!refType) {
176 mlir::emitError(loc, "invalid type for wasm.func argument. Expecting "
177 "!wasm<local T>, got ")
178 << argType;
179 return;
180 }
181 argTypesWithoutLocal.push_back(refType.getElementType());
182 });
183
184 return builder.getFunctionType(argTypesWithoutLocal, results);
185 };
187 parser, result, /*allowVariadic=*/false,
188 getFunctionTypeAttrName(result.name), buildFuncType,
189 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
190 if (exported)
191 result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
192 return funcParseRes;
193}
194
195LogicalResult FuncOp::verifyBody() {
196 if (getBody().empty())
197 return success();
198 Block &entry = getBody().front();
199 if (entry.getNumArguments() != getFunctionType().getNumInputs())
200 return emitError("entry block should have same number of arguments as "
201 "function type. Function type has ")
202 << getFunctionType().getNumInputs() << ", entry block has "
203 << entry.getNumArguments();
204
205 for (auto [argNo, funcSignatureType, blockType] : llvm::enumerate(
206 getFunctionType().getInputs(), entry.getArgumentTypes())) {
207 auto blockLocalRefType = dyn_cast<LocalRefType>(blockType);
208 if (!blockLocalRefType)
209 return emitError("entry block argument type should be LocalRefType, got ")
210 << blockType << " for block argument " << argNo;
211 if (blockLocalRefType.getElementType() != funcSignatureType)
212 return emitError("func argument type #")
213 << argNo << "(" << funcSignatureType
214 << ") doesn't match entry block referenced type ("
215 << blockLocalRefType.getElementType() << ")";
216 }
217 return success();
218}
219
220void FuncOp::print(OpAsmPrinter &p) {
221 /// If exported, print it before and mask it before printing
222 /// using generic interface.
223 auto exported = getExported();
224 if (exported) {
225 p << " exported";
226 removeExportedAttr();
227 }
229 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
230 getArgAttrsAttrName(), getResAttrsAttrName());
231 if (exported)
232 setExported(true);
233}
234
235//===----------------------------------------------------------------------===//
236// FuncImportOp
237//===----------------------------------------------------------------------===//
238
239void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
240 StringRef symbol, StringRef moduleName,
241 StringRef importName, FunctionType type) {
242 FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
243 type, {}, {});
244}
245
246//===----------------------------------------------------------------------===//
247// GlobalOp
248//===----------------------------------------------------------------------===//
249// Custom formats
250ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
251 StringAttr symbolName;
252 Type globalType;
253 auto *ctx = parser.getContext();
254 std::string visibilityString;
255 auto loc = parser.getNameLoc();
256 ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
257 if (res.succeeded()) {
258 if (visibilityString != "exported")
259 return parser.emitError(
260 loc, "expecting either `exported` or symbol name. got ")
261 << visibilityString;
262 result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
263 }
264
265 res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
266 result.attributes);
267 res = parser.parseType(globalType);
268 result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
269 std::string mutableString;
270 res = parser.parseOptionalKeywordOrString(&mutableString);
271 if (res.succeeded() && mutableString == "mutable")
272 result.addAttribute("isMutable", UnitAttr::get(ctx));
273
274 res = parser.parseColon();
275 Region *globalInitRegion = result.addRegion();
276 res = parser.parseRegion(*globalInitRegion);
277 return res;
278}
279
280void GlobalOp::print(OpAsmPrinter &printer) {
281 if (getExported())
282 printer << " exported";
283 printer << " @" << getSymName().str() << " " << getType();
284 if (getIsMutable())
285 printer << " mutable";
286 printer << " :";
287 Region &body = getRegion();
288 if (!body.empty()) {
289 printer << ' ';
290 printer.printRegion(body, /*printEntryBlockArgs=*/false,
291 /*printBlockTerminators=*/true);
292 }
293}
294
295//===----------------------------------------------------------------------===//
296// GlobalGetOp
297//===----------------------------------------------------------------------===//
298
299LogicalResult
300GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
301 // If the parent requires a constant context, verify that global.get is a
302 // constant as defined per the wasm standard.
303 if (!this->getOperation()
304 ->getParentWithTrait<ConstantExpressionInitializerOpTrait>())
305 return success();
307 StringRef referencedSymbol = getGlobal();
308 Operation *definitionOp = symbolTable.lookupSymbolIn(
309 symTabOp, StringAttr::get(this->getContext(), referencedSymbol));
310 if (!definitionOp)
311 return emitError() << "symbol @" << referencedSymbol << " is undefined";
312 auto definitionImport = dyn_cast<GlobalImportOp>(definitionOp);
313 if (!definitionImport || definitionImport.getIsMutable()) {
314 return emitError("global.get op is considered constant if it's referring "
315 "to a import.global symbol marked non-mutable");
316 }
317 return success();
318}
319
320//===----------------------------------------------------------------------===//
321// GlobalImportOp
322//===----------------------------------------------------------------------===//
323
324ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
325 auto *ctx = parser.getContext();
326 ParseResult res = parseImportOp(parser, result);
327 if (res.failed())
328 return failure();
329 std::string mutableOrSymVisString;
330 res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
331 if (res.succeeded() && mutableOrSymVisString == "mutable") {
332 result.addAttribute("isMutable", UnitAttr::get(ctx));
333 }
334
335 res = parser.parseColon();
336
337 Type importedType;
338 res = parser.parseType(importedType);
339 if (res.succeeded())
340 result.addAttribute(getTypeAttrName(result.name),
341 TypeAttr::get(importedType));
342 return res;
343}
344
345void GlobalImportOp::print(OpAsmPrinter &printer) {
346 printer << " \"" << getImportName() << "\" from \"" << getModuleName()
347 << "\" as @" << getSymName();
348 if (getIsMutable())
349 printer << " mutable";
350 printer << " : " << getType();
351}
352
353//===----------------------------------------------------------------------===//
354// IfOp
355//===----------------------------------------------------------------------===//
356
357Block *IfOp::getLabelTarget() { return getTarget(); }
358
359//===----------------------------------------------------------------------===//
360// LocalOp
361//===----------------------------------------------------------------------===//
362
363LogicalResult LocalOp::inferReturnTypes(
364 MLIRContext *context, ::std::optional<Location> location,
365 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
366 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
367 LocalOp::GenericAdaptor<ValueRange> adaptor{operands, attributes, properties,
368 regions};
369 auto type = adaptor.getTypeAttr();
370 if (!type)
371 return failure();
372 auto resType = LocalRefType::get(type.getContext(), type.getValue());
373 inferredReturnTypes.push_back(resType);
374 return success();
375}
376
377//===----------------------------------------------------------------------===//
378// LocalGetOp
379//===----------------------------------------------------------------------===//
380
381LogicalResult LocalGetOp::inferReturnTypes(
382 MLIRContext *context, ::std::optional<Location> location,
383 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
384 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
385 return inferTeeGetResType(operands, inferredReturnTypes);
386}
387
388//===----------------------------------------------------------------------===//
389// LocalSetOp
390//===----------------------------------------------------------------------===//
391
392LogicalResult LocalSetOp::verify() {
393 if (getLocalVar().getType().getElementType() != getValue().getType())
394 return emitError("input type and result type of local.set do not match");
395 return success();
396}
397
398//===----------------------------------------------------------------------===//
399// LocalTeeOp
400//===----------------------------------------------------------------------===//
401
402LogicalResult LocalTeeOp::inferReturnTypes(
403 MLIRContext *context, ::std::optional<Location> location,
404 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
405 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
406 return inferTeeGetResType(operands, inferredReturnTypes);
407}
408
409LogicalResult LocalTeeOp::verify() {
410 if (getLocalVar().getType().getElementType() != getValue().getType() ||
411 getValue().getType() != getResult().getType())
412 return emitError("input type and output type of local.tee do not match");
413 return success();
414}
415
416//===----------------------------------------------------------------------===//
417// LoopOp
418//===----------------------------------------------------------------------===//
419
420Block *LoopOp::getLabelTarget() { return &getBody().front(); }
421
422//===----------------------------------------------------------------------===//
423// ReinterpretOp
424//===----------------------------------------------------------------------===//
425
426LogicalResult ReinterpretOp::verify() {
427 auto inT = getInput().getType();
428 auto resT = getResult().getType();
429 if (inT == resT)
430 return emitError("reinterpret input and output type should be distinct");
431 if (inT.getIntOrFloatBitWidth() != resT.getIntOrFloatBitWidth())
432 return emitError() << "input type (" << inT << ") and output type (" << resT
433 << ") have incompatible bit widths";
434 return success();
435}
436
437//===----------------------------------------------------------------------===//
438// ReturnOp
439//===----------------------------------------------------------------------===//
440
441void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
return success()
static Type getElementType(Type type)
Determine the element type of type.
b getContext())
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
MLIRContext * getContext() const
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
ParseResult parseString(std::string *string)
Parse a quoted string token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual void printKeywordOrString(StringRef keyword)
Print the given string as a keyword, or a quoted and escaped string if it has any special or non-prin...
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
Block * getSuccessor(unsigned i)
Definition Block.cpp:269
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:207
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
bool empty()
Definition Region.h:60
This class represents a collection of SymbolTables.
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This represents an operation in an abstracted form, suitable for use with the builder APIs.