MLIR  20.0.0git
PDLInterp.cpp
Go to the documentation of this file.
1 //===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/BuiltinTypes.h"
14 
15 using namespace mlir;
16 using namespace mlir::pdl_interp;
17 
18 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
19 
20 //===----------------------------------------------------------------------===//
21 // PDLInterp Dialect
22 //===----------------------------------------------------------------------===//
23 
/// Dialect initialization hook: registers the pdl_interp operations with the
/// dialect via addOperations<>. The template argument list is not written out
/// here; it comes from the TableGen-generated PDLInterpOps.cpp.inc, included
/// with GET_OP_LIST defined (presumably selecting its op-class-list section,
/// per the usual MLIR TableGen convention — confirm against the .inc file).
24 void PDLInterpDialect::initialize() {
25  addOperations<
26 #define GET_OP_LIST
27 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
28  >();
29 }
30 
31 template <typename OpT>
32 static LogicalResult verifySwitchOp(OpT op) {
33  // Verify that the number of case destinations matches the number of case
34  // values.
35  size_t numDests = op.getCases().size();
36  size_t numValues = op.getCaseValues().size();
37  if (numDests != numValues) {
38  return op.emitOpError(
39  "expected number of cases to match the number of case "
40  "values, got ")
41  << numDests << " but expected " << numValues;
42  }
43  return success();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // pdl_interp::CreateOperationOp
48 //===----------------------------------------------------------------------===//
49 
50 LogicalResult CreateOperationOp::verify() {
51  if (!getInferredResultTypes())
52  return success();
53  if (!getInputResultTypes().empty()) {
54  return emitOpError("with inferred results cannot also have "
55  "explicit result types");
56  }
57  OperationName opName(getName(), getContext());
58  if (!opName.hasInterface<InferTypeOpInterface>()) {
59  return emitOpError()
60  << "has inferred results, but the created operation '" << opName
61  << "' does not support result type inference (or is not "
62  "registered)";
63  }
64  return success();
65 }
66 
68  OpAsmParser &p,
70  ArrayAttr &attrNamesAttr) {
71  Builder &builder = p.getBuilder();
72  SmallVector<Attribute, 4> attrNames;
73  if (succeeded(p.parseOptionalLBrace())) {
74  auto parseOperands = [&]() {
75  StringAttr nameAttr;
77  if (p.parseAttribute(nameAttr) || p.parseEqual() ||
78  p.parseOperand(operand))
79  return failure();
80  attrNames.push_back(nameAttr);
81  attrOperands.push_back(operand);
82  return success();
83  };
84  if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
85  return failure();
86  }
87  attrNamesAttr = builder.getArrayAttr(attrNames);
88  return success();
89 }
90 
92  CreateOperationOp op,
93  OperandRange attrArgs,
94  ArrayAttr attrNames) {
95  if (attrNames.empty())
96  return;
97  p << " {";
98  interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
99  [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
100  p << '}';
101 }
102 
103 static ParseResult parseCreateOperationOpResults(
104  OpAsmParser &p,
106  SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
107  if (failed(p.parseOptionalArrow()))
108  return success();
109 
110  // Handle the case of inferred results.
111  if (succeeded(p.parseOptionalLess())) {
112  if (p.parseKeyword("inferred") || p.parseGreater())
113  return failure();
114  inferredResultTypes = p.getBuilder().getUnitAttr();
115  return success();
116  }
117 
118  // Otherwise, parse the explicit results.
119  return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
120  p.parseColonTypeList(resultTypes) || p.parseRParen());
121 }
122 
123 static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
124  OperandRange resultOperands,
125  TypeRange resultTypes,
126  UnitAttr inferredResultTypes) {
127  // Handle the case of inferred results.
128  if (inferredResultTypes) {
129  p << " -> <inferred>";
130  return;
131  }
132 
133  // Otherwise, handle the explicit results.
134  if (!resultTypes.empty())
135  p << " -> (" << resultOperands << " : " << resultTypes << ")";
136 }
137 
138 //===----------------------------------------------------------------------===//
139 // pdl_interp::ForEachOp
140 //===----------------------------------------------------------------------===//
141 
142 void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
143  Value range, Block *successor, bool initLoop) {
144  build(builder, state, range, successor);
145  if (initLoop) {
146  // Create the block and the loop variable.
147  // FIXME: Allow passing in a proper location for the loop variable.
148  auto rangeType = llvm::cast<pdl::RangeType>(range.getType());
149  state.regions.front()->emplaceBlock();
150  state.regions.front()->addArgument(rangeType.getElementType(),
151  state.location);
152  }
153 }
154 
155 ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
156  // Parse the loop variable followed by type.
157  OpAsmParser::Argument loopVariable;
158  OpAsmParser::UnresolvedOperand operandInfo;
159  if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
160  parser.parseKeyword("in", " after loop variable") ||
161  // Parse the operand (value range).
162  parser.parseOperand(operandInfo))
163  return failure();
164 
165  // Resolve the operand.
166  Type rangeType = pdl::RangeType::get(loopVariable.type);
167  if (parser.resolveOperand(operandInfo, rangeType, result.operands))
168  return failure();
169 
170  // Parse the body region.
171  Region *body = result.addRegion();
172  Block *successor;
173  if (parser.parseRegion(*body, loopVariable) ||
174  parser.parseOptionalAttrDict(result.attributes) ||
175  // Parse the successor.
176  parser.parseArrow() || parser.parseSuccessor(successor))
177  return failure();
178 
179  result.addSuccessors(successor);
180  return success();
181 }
182 
184  BlockArgument arg = getLoopVariable();
185  p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
186  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
187  p.printOptionalAttrDict((*this)->getAttrs());
188  p << " -> ";
189  p.printSuccessor(getSuccessor());
190 }
191 
192 LogicalResult ForEachOp::verify() {
193  // Verify that the operation has exactly one argument.
194  if (getRegion().getNumArguments() != 1)
195  return emitOpError("requires exactly one argument");
196 
197  // Verify that the loop variable and the operand (value range)
198  // have compatible types.
199  BlockArgument arg = getLoopVariable();
200  Type rangeType = pdl::RangeType::get(arg.getType());
201  if (rangeType != getValues().getType())
202  return emitOpError("operand must be a range of loop variable type");
203 
204  return success();
205 }
206 
207 //===----------------------------------------------------------------------===//
208 // pdl_interp::FuncOp
209 //===----------------------------------------------------------------------===//
210 
211 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
212  FunctionType type, ArrayRef<NamedAttribute> attrs) {
213  buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
214 }
215 
216 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
217  auto buildFuncType =
218  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
220  std::string &) { return builder.getFunctionType(argTypes, results); };
221 
223  parser, result, /*allowVariadic=*/false,
224  getFunctionTypeAttrName(result.name), buildFuncType,
225  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
226 }
227 
228 void FuncOp::print(OpAsmPrinter &p) {
230  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
231  getArgAttrsAttrName(), getResAttrsAttrName());
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // pdl_interp::GetValueTypeOp
236 //===----------------------------------------------------------------------===//
237 
238 /// Given the result type of a `GetValueTypeOp`, return the expected input type.
240  Type valueTy = pdl::ValueType::get(type.getContext());
241  return llvm::isa<pdl::RangeType>(type) ? pdl::RangeType::get(valueTy)
242  : valueTy;
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // pdl::CreateRangeOp
247 //===----------------------------------------------------------------------===//
248 
249 static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
250  Type &resultType) {
251  // If arguments were provided, infer the result type from the argument list.
252  if (!argumentTypes.empty()) {
253  resultType =
255  return success();
256  }
257  // Otherwise, parse the type as a trailing type.
258  return p.parseColonType(resultType);
259 }
260 
261 static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
262  TypeRange argumentTypes, Type resultType) {
263  if (argumentTypes.empty())
264  p << ": " << resultType;
265 }
266 
267 LogicalResult CreateRangeOp::verify() {
268  Type elementType = getType().getElementType();
269  for (Type operandType : getOperandTypes()) {
270  Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
271  if (operandElementType != elementType) {
272  return emitOpError("expected operand to have element type ")
273  << elementType << ", but got " << operandElementType;
274  }
275  }
276  return success();
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // pdl_interp::SwitchAttributeOp
281 //===----------------------------------------------------------------------===//
282 
283 LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
284 
285 //===----------------------------------------------------------------------===//
286 // pdl_interp::SwitchOperandCountOp
287 //===----------------------------------------------------------------------===//
288 
289 LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
290 
291 //===----------------------------------------------------------------------===//
292 // pdl_interp::SwitchOperationNameOp
293 //===----------------------------------------------------------------------===//
294 
295 LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
296 
297 //===----------------------------------------------------------------------===//
298 // pdl_interp::SwitchResultCountOp
299 //===----------------------------------------------------------------------===//
300 
301 LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
302 
303 //===----------------------------------------------------------------------===//
304 // pdl_interp::SwitchTypeOp
305 //===----------------------------------------------------------------------===//
306 
307 LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
308 
309 //===----------------------------------------------------------------------===//
310 // pdl_interp::SwitchTypesOp
311 //===----------------------------------------------------------------------===//
312 
313 LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
314 
315 //===----------------------------------------------------------------------===//
316 // TableGen Auto-Generated Op and Interface Definitions
317 //===----------------------------------------------------------------------===//
318 
319 #define GET_OP_CLASSES
320 #include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
static MLIRContext * getContext(OpFoldResult val)
static void printRangeType(OpAsmPrinter &p, CreateRangeOp op, TypeRange argumentTypes, Type resultType)
Definition: PDLInterp.cpp:261
static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op, OperandRange resultOperands, TypeRange resultTypes, UnitAttr inferredResultTypes)
Definition: PDLInterp.cpp:123
static Type getGetValueTypeOpValueType(Type type)
Given the result type of a GetValueTypeOp, return the expected input type.
Definition: PDLInterp.cpp:239
static ParseResult parseCreateOperationOpResults(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &resultOperands, SmallVectorImpl< Type > &resultTypes, UnitAttr &inferredResultTypes)
Definition: PDLInterp.cpp:103
static ParseResult parseCreateOperationOpAttributes(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition: PDLInterp.cpp:67
static LogicalResult verifySwitchOp(OpT op)
Definition: PDLInterp.cpp:32
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, Type &resultType)
Definition: PDLInterp.cpp:249
static void printCreateOperationOpAttributes(OpAsmPrinter &p, CreateOperationOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition: PDLInterp.cpp:91
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printSuccessor(Block *successor)=0
Print the given successor.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition: Builders.h:216
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Type getRangeElementTypeOrSelf(Type type)
If the given type is a range, return its element type, otherwise return the type itself.
Definition: PDLTypes.cpp:62
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addSuccessors(Block *successor)
Adds a successor to the operation state. successor must not be null.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.