MLIR  22.0.0git
WasmSSAOps.cpp
Go to the documentation of this file.
1 //===- WasmSSAOps.cpp - WasmSSA dialect operations ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/Region.h"
17 #include "mlir/IR/SymbolTable.h"
19 #include "llvm/Support/Casting.h"
20 
21 //===----------------------------------------------------------------------===//
22 // TableGen'd op method definitions
23 //===----------------------------------------------------------------------===//
24 
25 using namespace mlir;
26 namespace {
27 ParseResult parseElseRegion(OpAsmParser &opParser, Region &elseRegion) {
28  std::string keyword;
29  std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
30  if (keyword == "else")
31  return opParser.parseRegion(elseRegion);
32  return ParseResult::success();
33 }
34 
35 void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
36  Region &elseRegion) {
37  if (elseRegion.empty())
38  return;
39  opPrinter.printKeywordOrString("else ");
40  opPrinter.printRegion(elseRegion);
41 }
42 
43 ParseResult parseWasmVisibility(OpAsmParser &opParser, StringAttr &visibility) {
44  std::string keyword;
45  auto initLocation = opParser.getCurrentLocation();
46  std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
47  if (keyword == "nested" or keyword == "") {
48  visibility = StringAttr::get(opParser.getContext(), "nested");
49  return ParseResult::success();
50  }
51 
52  if (keyword == "public" || keyword == "private") {
53  visibility = StringAttr::get(opParser.getContext(), keyword);
54  return ParseResult::success();
55  }
56  opParser.emitError(initLocation, "expecting symbol visibility");
57  return ParseResult::failure();
58 }
59 
60 void printWasmVisibility(OpAsmPrinter &opPrinter, Operation *op,
61  Attribute visibility) {
62  opPrinter.printKeywordOrString(cast<StringAttr>(visibility).strref());
63 }
64 } // namespace
65 
66 #define GET_OP_CLASSES
67 #include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc"
68 
70 #include "mlir/IR/Types.h"
71 #include "llvm/Support/LogicalResult.h"
72 
73 using namespace wasmssa;
74 
75 namespace {
76 inline LogicalResult
77 inferTeeGetResType(ValueRange operands,
78  SmallVectorImpl<Type> &inferredReturnTypes) {
79  if (operands.empty())
80  return failure();
81  auto opType = dyn_cast<LocalRefType>(operands.front().getType());
82  if (!opType)
83  return failure();
84  inferredReturnTypes.push_back(opType.getElementType());
85  return success();
86 }
87 
88 ParseResult parseImportOp(OpAsmParser &parser, OperationState &result) {
89  std::string importName;
90  auto *ctx = parser.getContext();
91  ParseResult res = parser.parseString(&importName);
92  result.addAttribute("importName", StringAttr::get(ctx, importName));
93 
94  std::string fromStr;
95  res = parser.parseKeywordOrString(&fromStr);
96  if (failed(res) || fromStr != "from")
97  return failure();
98 
99  std::string moduleName;
100  res = parser.parseString(&moduleName);
101  if (failed(res))
102  return failure();
103  result.addAttribute("moduleName", StringAttr::get(ctx, moduleName));
104 
105  std::string asStr;
106  res = parser.parseKeywordOrString(&asStr);
107  if (failed(res) || asStr != "as")
108  return failure();
109 
110  StringAttr symbolName;
111  res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
112  result.attributes);
113  return res;
114 }
115 } // namespace
116 
117 //===----------------------------------------------------------------------===//
118 // BlockOp
119 //===----------------------------------------------------------------------===//
120 
121 Block *BlockOp::getLabelTarget() { return getTarget(); }
122 
123 //===----------------------------------------------------------------------===//
124 // BlockReturnOp
125 //===----------------------------------------------------------------------===//
126 
127 std::size_t BlockReturnOp::getExitLevel() { return 0; }
128 
129 Block *BlockReturnOp::getTarget() {
130  return cast<LabelBranchingOpInterface>(getOperation())
131  .getTargetOp()
132  .getOperation()
133  ->getSuccessor(0);
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // ExtendLowBitsSOp
138 //===----------------------------------------------------------------------===//
139 
140 LogicalResult ExtendLowBitsSOp::verify() {
141  auto bitsToTake = getBitsToTake().getValue().getLimitedValue();
142  if (bitsToTake != 32 && bitsToTake != 16 && bitsToTake != 8)
143  return emitError("extend op can only take 8, 16 or 32 bits. Got ")
144  << bitsToTake;
145 
146  if (bitsToTake >= getInput().getType().getIntOrFloatBitWidth())
147  return emitError("trying to extend the ")
148  << bitsToTake << " low bits from a " << getInput().getType()
149  << " value is illegal";
150  return success();
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // FuncOp
155 //===----------------------------------------------------------------------===//
156 
157 Block *FuncOp::addEntryBlock() {
158  if (!getBody().empty()) {
159  emitError("adding entry block to a FuncOp which already has one");
160  return &getBody().front();
161  }
162  Block &block = getBody().emplaceBlock();
163  for (auto argType : getFunctionType().getInputs())
164  block.addArgument(LocalRefType::get(argType), getLoc());
165  return &block;
166 }
167 
168 void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
169  StringRef symbol, FunctionType funcType) {
170  FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {}, "nested");
171 }
172 
173 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
174  auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes,
175  ArrayRef<Type> results,
177  std::string &) {
178  SmallVector<Type> argTypesWithoutLocal{};
179  argTypesWithoutLocal.reserve(argTypes.size());
180  llvm::for_each(argTypes, [&parser, &argTypesWithoutLocal](Type argType) {
181  auto refType = dyn_cast<LocalRefType>(argType);
182  auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
183  if (!refType) {
184  mlir::emitError(loc, "invalid type for wasm.func argument. Expecting "
185  "!wasm<local T>, got ")
186  << argType;
187  return;
188  }
189  argTypesWithoutLocal.push_back(refType.getElementType());
190  });
191 
192  return builder.getFunctionType(argTypesWithoutLocal, results);
193  };
194 
196  parser, result, /*allowVariadic=*/false,
197  getFunctionTypeAttrName(result.name), buildFuncType,
198  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
199 }
200 
201 LogicalResult FuncOp::verifyBody() {
202  if (getBody().empty())
203  return success();
204  Block &entry = getBody().front();
205  if (entry.getNumArguments() != getFunctionType().getNumInputs())
206  return emitError("entry block should have same number of arguments as "
207  "function type. Function type has ")
208  << getFunctionType().getNumInputs() << ", entry block has "
209  << entry.getNumArguments();
210 
211  for (auto [argNo, funcSignatureType, blockType] : llvm::enumerate(
212  getFunctionType().getInputs(), entry.getArgumentTypes())) {
213  auto blockLocalRefType = dyn_cast<LocalRefType>(blockType);
214  if (!blockLocalRefType)
215  return emitError("entry block argument type should be LocalRefType, got ")
216  << blockType << " for block argument " << argNo;
217  if (blockLocalRefType.getElementType() != funcSignatureType)
218  return emitError("func argument type #")
219  << argNo << "(" << funcSignatureType
220  << ") doesn't match entry block referenced type ("
221  << blockLocalRefType.getElementType() << ")";
222  }
223  return success();
224 }
225 
226 void FuncOp::print(OpAsmPrinter &p) {
228  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
229  getArgAttrsAttrName(), getResAttrsAttrName());
230 }
231 
232 //===----------------------------------------------------------------------===//
233 // FuncImportOp
234 //===----------------------------------------------------------------------===//
235 
236 void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
237  StringRef symbol, StringRef moduleName,
238  StringRef importName, FunctionType type) {
239  FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
240  type, {}, {}, odsBuilder.getStringAttr("nested"));
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // GlobalOp
245 //===----------------------------------------------------------------------===//
246 
247 void GlobalOp::build(OpBuilder &odsBuilder, OperationState &odsState,
248  StringRef symbol, Type type, bool isMutable) {
249  GlobalOp::build(odsBuilder, odsState, symbol, type, isMutable,
250  odsBuilder.getStringAttr("nested"));
251 }
252 
253 // Custom formats
254 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
255  StringAttr symbolName;
256  Type globalType;
257  auto *ctx = parser.getContext();
258  ParseResult res = parser.parseSymbolName(
259  symbolName, SymbolTable::getSymbolAttrName(), result.attributes);
260 
261  res = parser.parseType(globalType);
262  result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
263  std::string mutableString;
264  res = parser.parseOptionalKeywordOrString(&mutableString);
265  if (res.succeeded() && mutableString == "mutable")
266  result.addAttribute("isMutable", UnitAttr::get(ctx));
267  std::string visibilityString;
268  res = parser.parseOptionalKeywordOrString(&visibilityString);
269  if (res.succeeded())
270  result.addAttribute("sym_visibility",
271  StringAttr::get(ctx, visibilityString));
272  res = parser.parseColon();
273  Region *globalInitRegion = result.addRegion();
274  res = parser.parseRegion(*globalInitRegion);
275  return res;
276 }
277 
278 void GlobalOp::print(OpAsmPrinter &printer) {
279  printer << " @" << getSymName().str() << " " << getType();
280  if (getIsMutable())
281  printer << " mutable";
282  if (auto vis = getSymVisibility())
283  printer << " " << *vis;
284  printer << " :";
285  Region &body = getRegion();
286  if (!body.empty()) {
287  printer << ' ';
288  printer.printRegion(body, /*printEntryBlockArgs=*/false,
289  /*printBlockTerminators=*/true);
290  }
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // GlobalGetOp
295 //===----------------------------------------------------------------------===//
296 
297 LogicalResult
298 GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
299  // If the parent requires a constant context, verify that global.get is a
300  // constant as defined per the wasm standard.
301  if (!this->getOperation()
302  ->getParentWithTrait<ConstantExpressionInitializerOpTrait>())
303  return success();
305  StringRef referencedSymbol = getGlobal();
306  Operation *definitionOp = symbolTable.lookupSymbolIn(
307  symTabOp, StringAttr::get(this->getContext(), referencedSymbol));
308  if (!definitionOp)
309  return emitError() << "symbol @" << referencedSymbol << " is undefined";
310  auto definitionImport = dyn_cast<GlobalImportOp>(definitionOp);
311  if (!definitionImport || definitionImport.getIsMutable()) {
312  return emitError("global.get op is considered constant if it's referring "
313  "to a import.global symbol marked non-mutable");
314  }
315  return success();
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // GlobalImportOp
320 //===----------------------------------------------------------------------===//
321 
322 void GlobalImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
323  StringRef symbol, StringRef moduleName,
324  StringRef importName, Type type, bool isMutable) {
325  GlobalImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
326  type, isMutable, odsBuilder.getStringAttr("nested"));
327 }
328 
329 ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
330  auto *ctx = parser.getContext();
331  ParseResult res = parseImportOp(parser, result);
332  if (res.failed())
333  return failure();
334  std::string mutableOrSymVisString;
335  res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
336  if (res.succeeded() && mutableOrSymVisString == "mutable") {
337  result.addAttribute("isMutable", UnitAttr::get(ctx));
338  res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
339  }
340 
341  if (res.succeeded())
342  result.addAttribute("sym_visibility",
343  StringAttr::get(ctx, mutableOrSymVisString));
344  res = parser.parseColon();
345 
346  Type importedType;
347  res = parser.parseType(importedType);
348  if (res.succeeded())
349  result.addAttribute(getTypeAttrName(result.name),
350  TypeAttr::get(importedType));
351  return res;
352 }
353 
354 void GlobalImportOp::print(OpAsmPrinter &printer) {
355  printer << " \"" << getImportName() << "\" from \"" << getModuleName()
356  << "\" as @" << getSymName();
357  if (getIsMutable())
358  printer << " mutable";
359  if (auto vis = getSymVisibility())
360  printer << " " << *vis;
361  printer << " : " << getType();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // IfOp
366 //===----------------------------------------------------------------------===//
367 
368 Block *IfOp::getLabelTarget() { return getTarget(); }
369 
370 //===----------------------------------------------------------------------===//
371 // LocalOp
372 //===----------------------------------------------------------------------===//
373 
374 LogicalResult LocalOp::inferReturnTypes(
375  MLIRContext *context, ::std::optional<Location> location,
376  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
377  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
378  LocalOp::GenericAdaptor<ValueRange> adaptor{operands, attributes, properties,
379  regions};
380  auto type = adaptor.getTypeAttr();
381  if (!type)
382  return failure();
383  auto resType = LocalRefType::get(type.getContext(), type.getValue());
384  inferredReturnTypes.push_back(resType);
385  return success();
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // LocalGetOp
390 //===----------------------------------------------------------------------===//
391 
392 LogicalResult LocalGetOp::inferReturnTypes(
393  MLIRContext *context, ::std::optional<Location> location,
394  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
395  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
396  return inferTeeGetResType(operands, inferredReturnTypes);
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // LocalSetOp
401 //===----------------------------------------------------------------------===//
402 
403 LogicalResult LocalSetOp::verify() {
404  if (getLocalVar().getType().getElementType() != getValue().getType())
405  return emitError("input type and result type of local.set do not match");
406  return success();
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // LocalTeeOp
411 //===----------------------------------------------------------------------===//
412 
413 LogicalResult LocalTeeOp::inferReturnTypes(
414  MLIRContext *context, ::std::optional<Location> location,
415  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
416  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
417  return inferTeeGetResType(operands, inferredReturnTypes);
418 }
419 
420 LogicalResult LocalTeeOp::verify() {
421  if (getLocalVar().getType().getElementType() != getValue().getType() ||
422  getValue().getType() != getResult().getType())
423  return emitError("input type and output type of local.tee do not match");
424  return success();
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // LoopOp
429 //===----------------------------------------------------------------------===//
430 
431 Block *LoopOp::getLabelTarget() { return &getBody().front(); }
432 
433 //===----------------------------------------------------------------------===//
434 // MemOp
435 //===----------------------------------------------------------------------===//
436 
437 void MemOp::build(OpBuilder &odsBuilder, OperationState &odsState,
438  StringRef symbol, LimitType limit) {
439  MemOp::build(odsBuilder, odsState, symbol, limit,
440  odsBuilder.getStringAttr("nested"));
441 }
442 
443 //===----------------------------------------------------------------------===//
444 // MemImportOp
445 //===----------------------------------------------------------------------===//
446 
447 void MemImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
448  StringRef symbol, StringRef moduleName,
449  StringRef importName, LimitType limits) {
450  MemImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
451  limits, odsBuilder.getStringAttr("nested"));
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // ReinterpretOp
456 //===----------------------------------------------------------------------===//
457 
458 LogicalResult ReinterpretOp::verify() {
459  auto inT = getInput().getType();
460  auto resT = getResult().getType();
461  if (inT == resT)
462  return emitError("reinterpret input and output type should be distinct");
463  if (inT.getIntOrFloatBitWidth() != resT.getIntOrFloatBitWidth())
464  return emitError() << "input type (" << inT << ") and output type (" << resT
465  << ") have incompatible bit widths";
466  return success();
467 }
468 
469 //===----------------------------------------------------------------------===//
470 // ReturnOp
471 //===----------------------------------------------------------------------===//
472 
473 void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
474 
475 //===----------------------------------------------------------------------===//
476 // TableOp
477 //===----------------------------------------------------------------------===//
478 
479 void TableOp::build(OpBuilder &odsBuilder, OperationState &odsState,
480  StringRef symbol, TableType type) {
481  TableOp::build(odsBuilder, odsState, symbol, type,
482  odsBuilder.getStringAttr("nested"));
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // TableImportOp
487 //===----------------------------------------------------------------------===//
488 
489 void TableImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
490  StringRef symbol, StringRef moduleName,
491  StringRef importName, TableType type) {
492  TableImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
493  type, odsBuilder.getStringAttr("nested"));
494 }
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
ParseResult parseString(std::string *string)
Parse a quoted string token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual void printKeywordOrString(StringRef keyword)
Print the given string as a keyword, or a quoted and escaped string if it has any special or non-printable characters in it.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
Operation & front()
Definition: Block.h:153
Block * getSuccessor(unsigned i)
Definition: Block.cpp:269
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:50
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:205
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
A named class for passing around the variadic flag.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.