MLIR  20.0.0git
MatchInterfaces.cpp
Go to the documentation of this file.
1 //===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // Printing and parsing for match ops.
15 //===----------------------------------------------------------------------===//
16 
17 /// Keyword syntax for positional specification inversion.
18 constexpr const static llvm::StringLiteral kDimExceptKeyword = "except";
19 
20 /// Keyword syntax for full inclusion in positional specification.
21 constexpr const static llvm::StringLiteral kDimAllKeyword = "all";
22 
24  DenseI64ArrayAttr &rawDimList,
25  UnitAttr &isInverted,
26  UnitAttr &isAll) {
27  Builder &builder = parser.getBuilder();
28  if (parser.parseOptionalKeyword(kDimAllKeyword).succeeded()) {
29  rawDimList = builder.getDenseI64ArrayAttr({});
30  isInverted = nullptr;
31  isAll = builder.getUnitAttr();
32  return success();
33  }
34 
35  isAll = nullptr;
36  isInverted = nullptr;
37  if (parser.parseOptionalKeyword(kDimExceptKeyword).succeeded()) {
38  isInverted = builder.getUnitAttr();
39  }
40 
41  if (isInverted) {
42  if (parser.parseLParen().failed())
43  return failure();
44  }
45 
46  SmallVector<int64_t> values;
47  ParseResult listResult = parser.parseCommaSeparatedList(
48  [&]() { return parser.parseInteger(values.emplace_back()); });
49  if (listResult.failed())
50  return failure();
51 
52  rawDimList = builder.getDenseI64ArrayAttr(values);
53 
54  if (isInverted) {
55  if (parser.parseRParen().failed())
56  return failure();
57  }
58  return success();
59 }
60 
62  DenseI64ArrayAttr rawDimList,
63  UnitAttr isInverted, UnitAttr isAll) {
64  if (isAll) {
65  printer << kDimAllKeyword;
66  return;
67  }
68  if (isInverted) {
69  printer << kDimExceptKeyword << "(";
70  }
71  llvm::interleaveComma(rawDimList.asArrayRef(), printer.getStream(),
72  [&](int64_t value) { printer << value; });
73  if (isInverted) {
74  printer << ")";
75  }
76 }
77 
80  bool inverted, bool all) {
81  if (all) {
82  if (inverted) {
83  return op->emitOpError()
84  << "cannot request both 'all' and 'inverted' values in the list";
85  }
86  if (!raw.empty()) {
87  return op->emitOpError()
88  << "cannot both request 'all' and specific values in the list";
89  }
90  }
91  if (!all && raw.empty()) {
92  return op->emitOpError() << "must request specific values in the list if "
93  "'all' is not specified";
94  }
95  SmallVector<int64_t> rawVector = llvm::to_vector(raw);
96  auto *it = llvm::unique(rawVector);
97  if (it != rawVector.end())
98  return op->emitOpError() << "expected the listed values to be unique";
99 
100  return success();
101 }
102 
104  Location loc, bool isAll, bool isInverted, ArrayRef<int64_t> rawList,
105  int64_t maxNumber, SmallVectorImpl<int64_t> &result) {
106  assert(maxNumber > 0 && "expected size to be positive");
107  assert(!(isAll && isInverted) && "cannot invert all");
108  if (isAll) {
109  result = llvm::to_vector(llvm::seq<int64_t>(0, maxNumber));
111  }
112 
113  SmallVector<int64_t> expanded;
114  llvm::SmallDenseSet<int64_t> visited;
115  expanded.reserve(rawList.size());
116  SmallVectorImpl<int64_t> &target = isInverted ? expanded : result;
117  for (int64_t raw : rawList) {
118  int64_t updated = raw < 0 ? maxNumber + raw : raw;
119  if (updated >= maxNumber) {
120  return emitSilenceableFailure(loc)
121  << "position overflow " << updated << " (updated from " << raw
122  << ") for maximum " << maxNumber;
123  }
124  if (updated < 0) {
125  return emitSilenceableFailure(loc) << "position underflow " << updated
126  << " (updated from " << raw << ")";
127  }
128  if (!visited.insert(updated).second) {
129  return emitSilenceableFailure(loc) << "repeated position " << updated
130  << " (updated from " << raw << ")";
131  }
132  target.push_back(updated);
133  }
134 
135  if (!isInverted)
137 
138  result.reserve(result.size() + (maxNumber - expanded.size()));
139  for (int64_t candidate : llvm::seq<int64_t>(0, maxNumber)) {
140  if (llvm::is_contained(expanded, candidate))
141  continue;
142  result.push_back(candidate);
143  }
144 
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // Generated interface implementation.
150 //===----------------------------------------------------------------------===//
151 
152 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.cpp.inc"
constexpr static const llvm::StringLiteral kDimAllKeyword
Keyword syntax for full inclusion in positional specification.
constexpr static const llvm::StringLiteral kDimExceptKeyword
Keyword syntax for positional specification inversion.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:138
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void printTransformMatchDims(OpAsmPrinter &printer, Operation *op, DenseI64ArrayAttr rawDimList, UnitAttr isInverted, UnitAttr isAll)
Prints a positional index specification for transform match operations.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
ParseResult parseTransformMatchDims(OpAsmParser &parser, DenseI64ArrayAttr &rawDimList, UnitAttr &isInverted, UnitAttr &isAll)
Parses a positional index specification for transform match operations.
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.