MLIR  22.0.0git
WasmSSAOps.cpp
Go to the documentation of this file.
1 //===- WasmSSAOps.cpp - WasmSSA dialect operations ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 
11 
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
21 
22 //===----------------------------------------------------------------------===//
23 // TableGen'd op method definitions
24 //===----------------------------------------------------------------------===//
25 
26 using namespace mlir;
27 namespace {
28 ParseResult parseElseRegion(OpAsmParser &opParser, Region &elseRegion) {
29  std::string keyword;
30  std::ignore = opParser.parseOptionalKeywordOrString(&keyword);
31  if (keyword == "else")
32  return opParser.parseRegion(elseRegion);
33  return ParseResult::success();
34 }
35 
36 void printElseRegion(OpAsmPrinter &opPrinter, Operation *op,
37  Region &elseRegion) {
38  if (elseRegion.empty())
39  return;
40  opPrinter.printKeywordOrString("else ");
41  opPrinter.printRegion(elseRegion);
42 }
43 } // namespace
44 
45 #define GET_OP_CLASSES
46 #include "mlir/Dialect/WasmSSA/IR/WasmSSAOps.cpp.inc"
47 
49 #include "mlir/IR/Types.h"
50 #include "llvm/Support/LogicalResult.h"
51 
52 using namespace wasmssa;
53 
54 namespace {
55 inline LogicalResult
56 inferTeeGetResType(ValueRange operands,
57  SmallVectorImpl<Type> &inferredReturnTypes) {
58  if (operands.empty())
59  return failure();
60  auto opType = dyn_cast<LocalRefType>(operands.front().getType());
61  if (!opType)
62  return failure();
63  inferredReturnTypes.push_back(opType.getElementType());
64  return success();
65 }
66 
67 ParseResult parseImportOp(OpAsmParser &parser, OperationState &result) {
68  std::string importName;
69  auto *ctx = parser.getContext();
70  ParseResult res = parser.parseString(&importName);
71  result.addAttribute("importName", StringAttr::get(ctx, importName));
72 
73  std::string fromStr;
74  res = parser.parseKeywordOrString(&fromStr);
75  if (failed(res) || fromStr != "from")
76  return failure();
77 
78  std::string moduleName;
79  res = parser.parseString(&moduleName);
80  if (failed(res))
81  return failure();
82  result.addAttribute("moduleName", StringAttr::get(ctx, moduleName));
83 
84  std::string asStr;
85  res = parser.parseKeywordOrString(&asStr);
86  if (failed(res) || asStr != "as")
87  return failure();
88 
89  StringAttr symbolName;
90  res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
91  result.attributes);
92  return res;
93 }
94 } // namespace
95 
96 //===----------------------------------------------------------------------===//
97 // BlockOp
98 //===----------------------------------------------------------------------===//
99 
100 Block *BlockOp::getLabelTarget() { return getTarget(); }
101 
102 //===----------------------------------------------------------------------===//
103 // BlockReturnOp
104 //===----------------------------------------------------------------------===//
105 
106 std::size_t BlockReturnOp::getExitLevel() { return 0; }
107 
108 Block *BlockReturnOp::getTarget() {
109  return cast<LabelBranchingOpInterface>(getOperation())
110  .getTargetOp()
111  .getOperation()
112  ->getSuccessor(0);
113 }
114 
115 //===----------------------------------------------------------------------===//
116 // ExtendLowBitsSOp
117 //===----------------------------------------------------------------------===//
118 
119 LogicalResult ExtendLowBitsSOp::verify() {
120  auto bitsToTake = getBitsToTake().getValue().getLimitedValue();
121  if (bitsToTake != 32 && bitsToTake != 16 && bitsToTake != 8)
122  return emitError("extend op can only take 8, 16 or 32 bits. Got ")
123  << bitsToTake;
124 
125  if (bitsToTake >= getInput().getType().getIntOrFloatBitWidth())
126  return emitError("trying to extend the ")
127  << bitsToTake << " low bits from a " << getInput().getType()
128  << " value is illegal";
129  return success();
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // FuncOp
134 //===----------------------------------------------------------------------===//
135 
136 Block *FuncOp::addEntryBlock() {
137  if (!getBody().empty()) {
138  emitError("adding entry block to a FuncOp which already has one");
139  return &getBody().front();
140  }
141  Block &block = getBody().emplaceBlock();
142  for (auto argType : getFunctionType().getInputs())
143  block.addArgument(LocalRefType::get(argType), getLoc());
144  return &block;
145 }
146 
147 void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
148  StringRef symbol, FunctionType funcType) {
149  FuncOp::build(odsBuilder, odsState, symbol, funcType, {}, {});
150 }
151 
152 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
153  auto *ctx = parser.getContext();
154  std::string visibilityString;
155  auto loc = parser.getNameLoc();
156  ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
157  bool exported{false};
158  if (res.succeeded()) {
159  if (visibilityString != "exported")
160  return parser.emitError(
161  loc, "expecting either `exported` or symbol name. got ")
162  << visibilityString;
163  exported = true;
164  }
165 
166  auto buildFuncType = [&parser](Builder &builder, ArrayRef<Type> argTypes,
167  ArrayRef<Type> results,
169  std::string &) {
170  SmallVector<Type> argTypesWithoutLocal{};
171  argTypesWithoutLocal.reserve(argTypes.size());
172  llvm::for_each(argTypes, [&parser, &argTypesWithoutLocal](Type argType) {
173  auto refType = dyn_cast<LocalRefType>(argType);
174  auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
175  if (!refType) {
176  mlir::emitError(loc, "invalid type for wasm.func argument. Expecting "
177  "!wasm<local T>, got ")
178  << argType;
179  return;
180  }
181  argTypesWithoutLocal.push_back(refType.getElementType());
182  });
183 
184  return builder.getFunctionType(argTypesWithoutLocal, results);
185  };
186  auto funcParseRes = function_interface_impl::parseFunctionOp(
187  parser, result, /*allowVariadic=*/false,
188  getFunctionTypeAttrName(result.name), buildFuncType,
189  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
190  if (exported)
191  result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
192  return funcParseRes;
193 }
194 
195 LogicalResult FuncOp::verifyBody() {
196  if (getBody().empty())
197  return success();
198  Block &entry = getBody().front();
199  if (entry.getNumArguments() != getFunctionType().getNumInputs())
200  return emitError("entry block should have same number of arguments as "
201  "function type. Function type has ")
202  << getFunctionType().getNumInputs() << ", entry block has "
203  << entry.getNumArguments();
204 
205  for (auto [argNo, funcSignatureType, blockType] : llvm::enumerate(
206  getFunctionType().getInputs(), entry.getArgumentTypes())) {
207  auto blockLocalRefType = dyn_cast<LocalRefType>(blockType);
208  if (!blockLocalRefType)
209  return emitError("entry block argument type should be LocalRefType, got ")
210  << blockType << " for block argument " << argNo;
211  if (blockLocalRefType.getElementType() != funcSignatureType)
212  return emitError("func argument type #")
213  << argNo << "(" << funcSignatureType
214  << ") doesn't match entry block referenced type ("
215  << blockLocalRefType.getElementType() << ")";
216  }
217  return success();
218 }
219 
220 void FuncOp::print(OpAsmPrinter &p) {
221  /// If exported, print it before and mask it before printing
222  /// using generic interface.
223  auto exported = getExported();
224  if (exported) {
225  p << " exported";
226  removeExportedAttr();
227  }
229  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
230  getArgAttrsAttrName(), getResAttrsAttrName());
231  if (exported)
232  setExported(true);
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // FuncImportOp
237 //===----------------------------------------------------------------------===//
238 
239 void FuncImportOp::build(OpBuilder &odsBuilder, OperationState &odsState,
240  StringRef symbol, StringRef moduleName,
241  StringRef importName, FunctionType type) {
242  FuncImportOp::build(odsBuilder, odsState, symbol, moduleName, importName,
243  type, {}, {});
244 }
245 
246 //===----------------------------------------------------------------------===//
247 // GlobalOp
248 //===----------------------------------------------------------------------===//
249 // Custom formats
250 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
251  StringAttr symbolName;
252  Type globalType;
253  auto *ctx = parser.getContext();
254  std::string visibilityString;
255  auto loc = parser.getNameLoc();
256  ParseResult res = parser.parseOptionalKeywordOrString(&visibilityString);
257  if (res.succeeded()) {
258  if (visibilityString != "exported")
259  return parser.emitError(
260  loc, "expecting either `exported` or symbol name. got ")
261  << visibilityString;
262  result.addAttribute(getExportedAttrName(result.name), UnitAttr::get(ctx));
263  }
264 
265  res = parser.parseSymbolName(symbolName, SymbolTable::getSymbolAttrName(),
266  result.attributes);
267  res = parser.parseType(globalType);
268  result.addAttribute(getTypeAttrName(result.name), TypeAttr::get(globalType));
269  std::string mutableString;
270  res = parser.parseOptionalKeywordOrString(&mutableString);
271  if (res.succeeded() && mutableString == "mutable")
272  result.addAttribute("isMutable", UnitAttr::get(ctx));
273 
274  res = parser.parseColon();
275  Region *globalInitRegion = result.addRegion();
276  res = parser.parseRegion(*globalInitRegion);
277  return res;
278 }
279 
280 void GlobalOp::print(OpAsmPrinter &printer) {
281  if (getExported())
282  printer << " exported";
283  printer << " @" << getSymName().str() << " " << getType();
284  if (getIsMutable())
285  printer << " mutable";
286  printer << " :";
287  Region &body = getRegion();
288  if (!body.empty()) {
289  printer << ' ';
290  printer.printRegion(body, /*printEntryBlockArgs=*/false,
291  /*printBlockTerminators=*/true);
292  }
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // GlobalGetOp
297 //===----------------------------------------------------------------------===//
298 
299 LogicalResult
300 GlobalGetOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
301  // If the parent requires a constant context, verify that global.get is a
302  // constant as defined per the wasm standard.
303  if (!this->getOperation()
304  ->getParentWithTrait<ConstantExpressionInitializerOpTrait>())
305  return success();
307  StringRef referencedSymbol = getGlobal();
308  Operation *definitionOp = symbolTable.lookupSymbolIn(
309  symTabOp, StringAttr::get(this->getContext(), referencedSymbol));
310  if (!definitionOp)
311  return emitError() << "symbol @" << referencedSymbol << " is undefined";
312  auto definitionImport = dyn_cast<GlobalImportOp>(definitionOp);
313  if (!definitionImport || definitionImport.getIsMutable()) {
314  return emitError("global.get op is considered constant if it's referring "
315  "to a import.global symbol marked non-mutable");
316  }
317  return success();
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // GlobalImportOp
322 //===----------------------------------------------------------------------===//
323 
324 ParseResult GlobalImportOp::parse(OpAsmParser &parser, OperationState &result) {
325  auto *ctx = parser.getContext();
326  ParseResult res = parseImportOp(parser, result);
327  if (res.failed())
328  return failure();
329  std::string mutableOrSymVisString;
330  res = parser.parseOptionalKeywordOrString(&mutableOrSymVisString);
331  if (res.succeeded() && mutableOrSymVisString == "mutable") {
332  result.addAttribute("isMutable", UnitAttr::get(ctx));
333  }
334 
335  res = parser.parseColon();
336 
337  Type importedType;
338  res = parser.parseType(importedType);
339  if (res.succeeded())
340  result.addAttribute(getTypeAttrName(result.name),
341  TypeAttr::get(importedType));
342  return res;
343 }
344 
345 void GlobalImportOp::print(OpAsmPrinter &printer) {
346  printer << " \"" << getImportName() << "\" from \"" << getModuleName()
347  << "\" as @" << getSymName();
348  if (getIsMutable())
349  printer << " mutable";
350  printer << " : " << getType();
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // IfOp
355 //===----------------------------------------------------------------------===//
356 
357 Block *IfOp::getLabelTarget() { return getTarget(); }
358 
359 //===----------------------------------------------------------------------===//
360 // LocalOp
361 //===----------------------------------------------------------------------===//
362 
363 LogicalResult LocalOp::inferReturnTypes(
364  MLIRContext *context, ::std::optional<Location> location,
365  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
366  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
367  LocalOp::GenericAdaptor<ValueRange> adaptor{operands, attributes, properties,
368  regions};
369  auto type = adaptor.getTypeAttr();
370  if (!type)
371  return failure();
372  auto resType = LocalRefType::get(type.getContext(), type.getValue());
373  inferredReturnTypes.push_back(resType);
374  return success();
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // LocalGetOp
379 //===----------------------------------------------------------------------===//
380 
381 LogicalResult LocalGetOp::inferReturnTypes(
382  MLIRContext *context, ::std::optional<Location> location,
383  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
384  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
385  return inferTeeGetResType(operands, inferredReturnTypes);
386 }
387 
388 //===----------------------------------------------------------------------===//
389 // LocalSetOp
390 //===----------------------------------------------------------------------===//
391 
392 LogicalResult LocalSetOp::verify() {
393  if (getLocalVar().getType().getElementType() != getValue().getType())
394  return emitError("input type and result type of local.set do not match");
395  return success();
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // LocalTeeOp
400 //===----------------------------------------------------------------------===//
401 
402 LogicalResult LocalTeeOp::inferReturnTypes(
403  MLIRContext *context, ::std::optional<Location> location,
404  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
405  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
406  return inferTeeGetResType(operands, inferredReturnTypes);
407 }
408 
409 LogicalResult LocalTeeOp::verify() {
410  if (getLocalVar().getType().getElementType() != getValue().getType() ||
411  getValue().getType() != getResult().getType())
412  return emitError("input type and output type of local.tee do not match");
413  return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // LoopOp
418 //===----------------------------------------------------------------------===//
419 
420 Block *LoopOp::getLabelTarget() { return &getBody().front(); }
421 
422 //===----------------------------------------------------------------------===//
423 // ReinterpretOp
424 //===----------------------------------------------------------------------===//
425 
426 LogicalResult ReinterpretOp::verify() {
427  auto inT = getInput().getType();
428  auto resT = getResult().getType();
429  if (inT == resT)
430  return emitError("reinterpret input and output type should be distinct");
431  if (inT.getIntOrFloatBitWidth() != resT.getIntOrFloatBitWidth())
432  return emitError() << "input type (" << inT << ") and output type (" << resT
433  << ") have incompatible bit widths";
434  return success();
435 }
436 
437 //===----------------------------------------------------------------------===//
438 // ReturnOp
439 //===----------------------------------------------------------------------===//
440 
441 void ReturnOp::build(OpBuilder &odsBuilder, OperationState &odsState) {}
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual Location getEncodedSourceLoc(SMLoc loc)=0
Re-encode the given source location as an MLIR location and return it.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseKeywordOrString(std::string *result)
Parse a keyword or a quoted string.
ParseResult parseString(std::string *string)
Parse a quoted string token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual void printKeywordOrString(StringRef keyword)
Print the given string as a keyword, or a quoted and escaped string if it has any special or non-prin...
Block represents an ordered list of Operations.
Definition: Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition: Block.cpp:149
unsigned getNumArguments()
Definition: Block.h:128
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
Operation & front()
Definition: Block.h:153
Block * getSuccessor(unsigned i)
Definition: Block.cpp:269
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:207
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
A named class for passing around the variadic flag.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.