MLIR  22.0.0git
PDLInterp.cpp
Go to the documentation of this file.
1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinTypes.h"
13 
14 using namespace mlir;
15 using namespace mlir::pdl_interp;
16 
17 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
18 
19 //===----------------------------------------------------------------------===//
20 // PDLInterp Dialect
21 //===----------------------------------------------------------------------===//
22 
/// Registers the pdl_interp operations with this dialect.
///
/// The template argument list of addOperations<> comes from the
/// TableGen-generated PDLInterpOps.cpp.inc. That file is included here with
/// GET_OP_LIST defined, presumably selecting the section that emits the
/// comma-separated op class list (TODO confirm against the generated file).
void PDLInterpDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
  >();
}
29 
30 template <typename OpT>
31 static LogicalResult verifySwitchOp(OpT op) {
32  // Verify that the number of case destinations matches the number of case
33  // values.
34  size_t numDests = op.getCases().size();
35  size_t numValues = op.getCaseValues().size();
36  if (numDests != numValues) {
37  return op.emitOpError(
38  "expected number of cases to match the number of case "
39  "values, got ")
40  << numDests << " but expected " << numValues;
41  }
42  return success();
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // pdl_interp::CreateOperationOp
47 //===----------------------------------------------------------------------===//
48 
49 LogicalResult CreateOperationOp::verify() {
50  if (!getInferredResultTypes())
51  return success();
52  if (!getInputResultTypes().empty()) {
53  return emitOpError("with inferred results cannot also have "
54  "explicit result types");
55  }
56  OperationName opName(getName(), getContext());
57  if (!opName.hasInterface<InferTypeOpInterface>()) {
58  return emitOpError()
59  << "has inferred results, but the created operation '" << opName
60  << "' does not support result type inference (or is not "
61  "registered)";
62  }
63  return success();
64 }
65 
67  OpAsmParser &p,
69  ArrayAttr &attrNamesAttr) {
70  Builder &builder = p.getBuilder();
71  SmallVector<Attribute, 4> attrNames;
72  if (succeeded(p.parseOptionalLBrace())) {
73  auto parseOperands = [&]() {
74  StringAttr nameAttr;
76  if (p.parseAttribute(nameAttr) || p.parseEqual() ||
77  p.parseOperand(operand))
78  return failure();
79  attrNames.push_back(nameAttr);
80  attrOperands.push_back(operand);
81  return success();
82  };
83  if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
84  return failure();
85  }
86  attrNamesAttr = builder.getArrayAttr(attrNames);
87  return success();
88 }
89 
91  CreateOperationOp op,
92  OperandRange attrArgs,
93  ArrayAttr attrNames) {
94  if (attrNames.empty())
95  return;
96  p << " {";
97  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
98  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
99  p << '}';
100 }
101 
102 static ParseResult parseCreateOperationOpResults(
103  OpAsmParser &p,
105  SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
106  if (failed(p.parseOptionalArrow()))
107  return success();
108 
109  // Handle the case of inferred results.
110  if (succeeded(p.parseOptionalLess())) {
111  if (p.parseKeyword("inferred") || p.parseGreater())
112  return failure();
113  inferredResultTypes = p.getBuilder().getUnitAttr();
114  return success();
115  }
116 
117  // Otherwise, parse the explicit results.
118  return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
119  p.parseColonTypeList(resultTypes) || p.parseRParen());
120 }
121 
122 static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
123  OperandRange resultOperands,
124  TypeRange resultTypes,
125  UnitAttr inferredResultTypes) {
126  // Handle the case of inferred results.
127  if (inferredResultTypes) {
128  p << " -> <inferred>";
129  return;
130  }
131 
132  // Otherwise, handle the explicit results.
133  if (!resultTypes.empty())
134  p << " -> (" << resultOperands << " : " << resultTypes << ")";
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // pdl_interp::ForEachOp
139 //===----------------------------------------------------------------------===//
140 
141 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
142  Value range, Block *successor, bool initLoop) {
143  build(builder, state, range, successor);
144  if (initLoop) {
145  // Create the block and the loop variable.
146  // FIXME: Allow passing in a proper location for the loop variable.
147  auto rangeType = llvm::cast<pdl::RangeType>(range.getType());
148  state.regions.front()->emplaceBlock();
149  state.regions.front()->addArgument(rangeType.getElementType(),
150  state.location);
151  }
152 }
153 
154 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
155  // Parse the loop variable followed by type.
156  OpAsmParser::Argument loopVariable;
157  OpAsmParser::UnresolvedOperand operandInfo;
158  if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
159  parser.parseKeyword("in", " after loop variable") ||
160  // Parse the operand (value range).
161  parser.parseOperand(operandInfo))
162  return failure();
163 
164  // Resolve the operand.
165  Type rangeType = pdl::RangeType::get(loopVariable.type);
166  if (parser.resolveOperand(operandInfo, rangeType, result.operands))
167  return failure();
168 
169  // Parse the body region.
170  Region *body = result.addRegion();
171  Block *successor;
172  if (parser.parseRegion(*body, loopVariable) ||
173  parser.parseOptionalAttrDict(result.attributes) ||
174  // Parse the successor.
175  parser.parseArrow() || parser.parseSuccessor(successor))
176  return failure();
177 
178  result.addSuccessors(successor);
179  return success();
180 }
181 
183  BlockArgument arg = getLoopVariable();
184  p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
185  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
186  p.printOptionalAttrDict((*this)->getAttrs());
187  p << " -> ";
188  p.printSuccessor(getSuccessor());
189 }
190 
191 LogicalResult ForEachOp::verify() {
192  // Verify that the operation has exactly one argument.
193  if (getRegion().getNumArguments() != 1)
194  return emitOpError("requires exactly one argument");
195 
196  // Verify that the loop variable and the operand (value range)
197  // have compatible types.
198  BlockArgument arg = getLoopVariable();
199  Type rangeType = pdl::RangeType::get(arg.getType());
200  if (rangeType != getValues().getType())
201  return emitOpError("operand must be a range of loop variable type");
202 
203  return success();
204 }
205 
206 //===----------------------------------------------------------------------===//
207 // pdl_interp::FuncOp
208 //===----------------------------------------------------------------------===//
209 
210 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
211  FunctionType type, ArrayRef<NamedAttribute> attrs) {
212  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
213 }
214 
215 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
216  auto buildFuncType =
217  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
219  std::string &) { return builder.getFunctionType(argTypes, results); };
220 
222  parser, result, /*allowVariadic=*/false,
223  getFunctionTypeAttrName(result.name), buildFuncType,
224  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
225 }
226 
227 void FuncOp::print(OpAsmPrinter &p) {
229  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
230  getArgAttrsAttrName(), getResAttrsAttrName());
231 }
232 
233 //===----------------------------------------------------------------------===//
234 // pdl_interp::GetValueTypeOp
235 //===----------------------------------------------------------------------===//
236 
237 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
239  Type valueTy = pdl::ValueType::get(type.getContext());
240  return llvm::isa<pdl::RangeType>(type) ? pdl::RangeType::get(valueTy)
241  : valueTy;
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // pdl::CreateRangeOp
246 //===----------------------------------------------------------------------===//
247 
248 static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
249  Type &resultType) {
250  // If arguments were provided, infer the result type from the argument list.
251  if (!argumentTypes.empty()) {
252  resultType =
254  return success();
255  }
256  // Otherwise, parse the type as a trailing type.
257  return p.parseColonType(resultType);
258 }
259 
260 static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
261  TypeRange argumentTypes, Type resultType) {
262  if (argumentTypes.empty())
263  p << ": " << resultType;
264 }
265 
266 LogicalResult CreateRangeOp::verify() {
267  Type elementType = getType().getElementType();
268  for (Type operandType : getOperandTypes()) {
269  Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
270  if (operandElementType != elementType) {
271  return emitOpError("expected operand to have element type ")
272  << elementType << ", but got " << operandElementType;
273  }
274  }
275  return success();
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // pdl_interp::SwitchAttributeOp
280 //===----------------------------------------------------------------------===//
281 
282 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
283 
284 //===----------------------------------------------------------------------===//
285 // pdl_interp::SwitchOperandCountOp
286 //===----------------------------------------------------------------------===//
287 
288 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
289 
290 //===----------------------------------------------------------------------===//
291 // pdl_interp::SwitchOperationNameOp
292 //===----------------------------------------------------------------------===//
293 
294 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
295 
296 //===----------------------------------------------------------------------===//
297 // pdl_interp::SwitchResultCountOp
298 //===----------------------------------------------------------------------===//
299 
300 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
301 
302 //===----------------------------------------------------------------------===//
303 // pdl_interp::SwitchTypeOp
304 //===----------------------------------------------------------------------===//
305 
306 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
307 
308 //===----------------------------------------------------------------------===//
309 // pdl_interp::SwitchTypesOp
310 //===----------------------------------------------------------------------===//
311 
312 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
313 
314 //===----------------------------------------------------------------------===//
315 // TableGen Auto-Generated Op and Interface Definitions
316 //===----------------------------------------------------------------------===//
317 
318 #define GET_OP_CLASSES
319 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static void printRangeType(OpAsmPrinter &p, CreateRangeOp op, TypeRange argumentTypes, Type resultType)
Definition: PDLInterp.cpp:260
static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op, OperandRange resultOperands, TypeRange resultTypes, UnitAttr inferredResultTypes)
Definition: PDLInterp.cpp:122
static Type getGetValueTypeOpValueType(Type type)
Given the result type of a GetValueTypeOp, return the expected input type.
Definition: PDLInterp.cpp:238
static ParseResult parseCreateOperationOpResults(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &resultOperands, SmallVectorImpl< Type > &resultTypes, UnitAttr &inferredResultTypes)
Definition: PDLInterp.cpp:102
static ParseResult parseCreateOperationOpAttributes(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition: PDLInterp.cpp:66
static LogicalResult verifySwitchOp(OpT op)
Definition: PDLInterp.cpp:31
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, Type &resultType)
Definition: PDLInterp.cpp:248
static void printCreateOperationOpAttributes(OpAsmPrinter &p, CreateOperationOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition: PDLInterp.cpp:90
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
UnitAttr getUnitAttr()
Definition: Builders.cpp:97
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printSuccessor(Block *successor)=0
Print the given successor.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:207
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Type getRangeElementTypeOrSelf(Type type)
If the given type is a range, return its element type, otherwise return the type itself.
Definition: PDLTypes.cpp:62
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addSuccessors(Block *successor)
Adds a successor to the operation state. successor must not be null.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.