MLIR 22.0.0git
PDLInterp.cpp
Go to the documentation of this file.
1//===- PDLInterp.cpp - PDL Interpreter Dialect ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13
14using namespace mlir;
15using namespace mlir::pdl_interp;
16
17#include "mlir/Dialect/PDLInterp/IR/PDLInterpOpsDialect.cpp.inc"
18
19//===----------------------------------------------------------------------===//
20// PDLInterp Dialect
21//===----------------------------------------------------------------------===//
22
/// Dialect initialization hook: registers with this dialect the operation
/// list that `GET_OP_LIST` selects from the generated `PDLInterpOps.cpp.inc`.
23void PDLInterpDialect::initialize() {
24 addOperations<
25#define GET_OP_LIST
26#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
27 >();
28}
29
30template <typename OpT>
31static LogicalResult verifySwitchOp(OpT op) {
32 // Verify that the number of case destinations matches the number of case
33 // values.
34 size_t numDests = op.getCases().size();
35 size_t numValues = op.getCaseValues().size();
36 if (numDests != numValues) {
37 return op.emitOpError(
38 "expected number of cases to match the number of case "
39 "values, got ")
40 << numDests << " but expected " << numValues;
41 }
42 return success();
43}
44
45//===----------------------------------------------------------------------===//
46// pdl_interp::CreateOperationOp
47//===----------------------------------------------------------------------===//
48
49LogicalResult CreateOperationOp::verify() {
50 if (!getInferredResultTypes())
51 return success();
52 if (!getInputResultTypes().empty()) {
53 return emitOpError("with inferred results cannot also have "
54 "explicit result types");
55 }
56 OperationName opName(getName(), getContext());
57 if (!opName.hasInterface<InferTypeOpInterface>()) {
58 return emitOpError()
59 << "has inferred results, but the created operation '" << opName
60 << "' does not support result type inference (or is not "
61 "registered)";
62 }
63 return success();
64}
65
67 OpAsmParser &p,
69 ArrayAttr &attrNamesAttr) {
70 Builder &builder = p.getBuilder();
72 if (succeeded(p.parseOptionalLBrace())) {
73 auto parseOperands = [&]() {
74 StringAttr nameAttr;
76 if (p.parseAttribute(nameAttr) || p.parseEqual() ||
77 p.parseOperand(operand))
78 return failure();
79 attrNames.push_back(nameAttr);
80 attrOperands.push_back(operand);
81 return success();
82 };
83 if (p.parseCommaSeparatedList(parseOperands) || p.parseRBrace())
84 return failure();
85 }
86 attrNamesAttr = builder.getArrayAttr(attrNames);
87 return success();
88}
89
91 CreateOperationOp op,
92 OperandRange attrArgs,
93 ArrayAttr attrNames) {
94 if (attrNames.empty())
95 return;
96 p << " {";
97 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
98 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
99 p << '}';
100}
101
103 OpAsmParser &p,
105 SmallVectorImpl<Type> &resultTypes, UnitAttr &inferredResultTypes) {
106 if (failed(p.parseOptionalArrow()))
107 return success();
108
109 // Handle the case of inferred results.
110 if (succeeded(p.parseOptionalLess())) {
111 if (p.parseKeyword("inferred") || p.parseGreater())
112 return failure();
113 inferredResultTypes = p.getBuilder().getUnitAttr();
114 return success();
115 }
116
117 // Otherwise, parse the explicit results.
118 return failure(p.parseLParen() || p.parseOperandList(resultOperands) ||
119 p.parseColonTypeList(resultTypes) || p.parseRParen());
120}
121
122static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op,
123 OperandRange resultOperands,
124 TypeRange resultTypes,
125 UnitAttr inferredResultTypes) {
126 // Handle the case of inferred results.
127 if (inferredResultTypes) {
128 p << " -> <inferred>";
129 return;
130 }
131
132 // Otherwise, handle the explicit results.
133 if (!resultTypes.empty())
134 p << " -> (" << resultOperands << " : " << resultTypes << ")";
135}
136
137//===----------------------------------------------------------------------===//
138// pdl_interp::ForEachOp
139//===----------------------------------------------------------------------===//
140
141void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
142 Value range, Block *successor, bool initLoop) {
143 build(builder, state, range, successor);
144 if (initLoop) {
145 // Create the block and the loop variable.
146 // FIXME: Allow passing in a proper location for the loop variable.
147 auto rangeType = llvm::cast<pdl::RangeType>(range.getType());
148 state.regions.front()->emplaceBlock();
149 state.regions.front()->addArgument(rangeType.getElementType(),
150 state.location);
151 }
152}
153
154ParseResult ForEachOp::parse(OpAsmParser &parser, OperationState &result) {
155 // Parse the loop variable followed by type.
156 OpAsmParser::Argument loopVariable;
158 if (parser.parseArgument(loopVariable, /*allowType=*/true) ||
159 parser.parseKeyword("in", " after loop variable") ||
160 // Parse the operand (value range).
161 parser.parseOperand(operandInfo))
162 return failure();
163
164 // Resolve the operand.
165 Type rangeType = pdl::RangeType::get(loopVariable.type);
166 if (parser.resolveOperand(operandInfo, rangeType, result.operands))
167 return failure();
168
169 // Parse the body region.
170 Region *body = result.addRegion();
171 Block *successor;
172 if (parser.parseRegion(*body, loopVariable) ||
173 parser.parseOptionalAttrDict(result.attributes) ||
174 // Parse the successor.
175 parser.parseArrow() || parser.parseSuccessor(successor))
176 return failure();
177
178 result.addSuccessors(successor);
179 return success();
180}
181
182void ForEachOp::print(OpAsmPrinter &p) {
183 BlockArgument arg = getLoopVariable();
184 p << ' ' << arg << " : " << arg.getType() << " in " << getValues() << ' ';
185 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
186 p.printOptionalAttrDict((*this)->getAttrs());
187 p << " -> ";
188 p.printSuccessor(getSuccessor());
189}
190
191LogicalResult ForEachOp::verify() {
192 // Verify that the operation has exactly one argument.
193 if (getRegion().getNumArguments() != 1)
194 return emitOpError("requires exactly one argument");
195
196 // Verify that the loop variable and the operand (value range)
197 // have compatible types.
198 BlockArgument arg = getLoopVariable();
199 Type rangeType = pdl::RangeType::get(arg.getType());
200 if (rangeType != getValues().getType())
201 return emitOpError("operand must be a range of loop variable type");
202
203 return success();
204}
205
206//===----------------------------------------------------------------------===//
207// pdl_interp::FuncOp
208//===----------------------------------------------------------------------===//
209
210void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
211 FunctionType type, ArrayRef<NamedAttribute> attrs) {
212 buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs());
213}
214
215ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
216 auto buildFuncType =
217 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
219 std::string &) { return builder.getFunctionType(argTypes, results); };
220
222 parser, result, /*allowVariadic=*/false,
223 getFunctionTypeAttrName(result.name), buildFuncType,
224 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
225}
226
227void FuncOp::print(OpAsmPrinter &p) {
229 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
230 getArgAttrsAttrName(), getResAttrsAttrName());
231}
232
233//===----------------------------------------------------------------------===//
234// pdl_interp::GetValueTypeOp
235//===----------------------------------------------------------------------===//
236
237/// Given the result type of a `GetValueTypeOp`, return the expected input type.
239 Type valueTy = pdl::ValueType::get(type.getContext());
240 return llvm::isa<pdl::RangeType>(type) ? pdl::RangeType::get(valueTy)
241 : valueTy;
242}
243
244//===----------------------------------------------------------------------===//
245// pdl::CreateRangeOp
246//===----------------------------------------------------------------------===//
247
248static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
249 Type &resultType) {
250 // If arguments were provided, infer the result type from the argument list.
251 if (!argumentTypes.empty()) {
252 resultType =
253 pdl::RangeType::get(pdl::getRangeElementTypeOrSelf(argumentTypes[0]));
254 return success();
255 }
256 // Otherwise, parse the type as a trailing type.
257 return p.parseColonType(resultType);
258}
259
260static void printRangeType(OpAsmPrinter &p, CreateRangeOp op,
261 TypeRange argumentTypes, Type resultType) {
262 if (argumentTypes.empty())
263 p << ": " << resultType;
264}
265
266LogicalResult CreateRangeOp::verify() {
267 Type elementType = getType().getElementType();
268 for (Type operandType : getOperandTypes()) {
269 Type operandElementType = pdl::getRangeElementTypeOrSelf(operandType);
270 if (operandElementType != elementType) {
271 return emitOpError("expected operand to have element type ")
272 << elementType << ", but got " << operandElementType;
273 }
274 }
275 return success();
276}
277
278//===----------------------------------------------------------------------===//
279// pdl_interp::SwitchAttributeOp
280//===----------------------------------------------------------------------===//
281
282LogicalResult SwitchAttributeOp::verify() { return verifySwitchOp(*this); }
283
284//===----------------------------------------------------------------------===//
285// pdl_interp::SwitchOperandCountOp
286//===----------------------------------------------------------------------===//
287
288LogicalResult SwitchOperandCountOp::verify() { return verifySwitchOp(*this); }
289
290//===----------------------------------------------------------------------===//
291// pdl_interp::SwitchOperationNameOp
292//===----------------------------------------------------------------------===//
293
294LogicalResult SwitchOperationNameOp::verify() { return verifySwitchOp(*this); }
295
296//===----------------------------------------------------------------------===//
297// pdl_interp::SwitchResultCountOp
298//===----------------------------------------------------------------------===//
299
300LogicalResult SwitchResultCountOp::verify() { return verifySwitchOp(*this); }
301
302//===----------------------------------------------------------------------===//
303// pdl_interp::SwitchTypeOp
304//===----------------------------------------------------------------------===//
305
306LogicalResult SwitchTypeOp::verify() { return verifySwitchOp(*this); }
307
308//===----------------------------------------------------------------------===//
309// pdl_interp::SwitchTypesOp
310//===----------------------------------------------------------------------===//
311
312LogicalResult SwitchTypesOp::verify() { return verifySwitchOp(*this); }
313
314//===----------------------------------------------------------------------===//
315// TableGen Auto-Generated Op and Interface Definitions
316//===----------------------------------------------------------------------===//
317
318#define GET_OP_CLASSES
319#include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
ArrayAttr()
b getContext())
static void printRangeType(OpAsmPrinter &p, CreateRangeOp op, TypeRange argumentTypes, Type resultType)
static void printCreateOperationOpResults(OpAsmPrinter &p, CreateOperationOp op, OperandRange resultOperands, TypeRange resultTypes, UnitAttr inferredResultTypes)
static Type getGetValueTypeOpValueType(Type type)
Given the result type of a GetValueTypeOp, return the expected input type.
static ParseResult parseCreateOperationOpResults(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &resultOperands, SmallVectorImpl< Type > &resultTypes, UnitAttr &inferredResultTypes)
static ParseResult parseCreateOperationOpAttributes(OpAsmParser &p, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &attrOperands, ArrayAttr &attrNamesAttr)
Definition PDLInterp.cpp:66
static LogicalResult verifySwitchOp(OpT op)
Definition PDLInterp.cpp:31
static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes, Type &resultType)
static void printCreateOperationOpAttributes(OpAsmPrinter &p, CreateOperationOp op, OperandRange attrArgs, ArrayAttr attrNames)
Definition PDLInterp.cpp:90
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:98
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printSuccessor(Block *successor)=0
Print the given successor.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:207
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
A named class for passing around the variadic flag.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Type getRangeElementTypeOrSelf(Type type)
If the given type is a range, return its element type, otherwise return the type itself.
Definition PDLTypes.cpp:62
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.