MLIR  19.0.0git
PolynomialAttributes.cpp
Go to the documentation of this file.
1 //===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 
11 #include "mlir/Support/LLVM.h"
13 #include "llvm/ADT/StringExtras.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/ADT/StringSet.h"
16 
17 namespace mlir {
18 namespace polynomial {
19 
20 void PolynomialAttr::print(AsmPrinter &p) const {
21  p << '<';
22  p << getPolynomial();
23  p << '>';
24 }
25 
26 /// Try to parse a monomial. If successful, populate the fields of the outparam
27 /// `monomial` with the results, and the `variable` outparam with the parsed
28 /// variable name. Sets shouldParseMore to true if the monomial is followed by
29 /// a '+'.
30 ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
31  llvm::StringRef &variable, bool &isConstantTerm,
32  bool &shouldParseMore) {
33  APInt parsedCoeff(apintBitWidth, 1);
34  auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);
35  monomial.coefficient = parsedCoeff;
36 
37  isConstantTerm = false;
38  shouldParseMore = false;
39 
40  // A + indicates it's a constant term with more to go, as in `1 + x`.
41  if (succeeded(parser.parseOptionalPlus())) {
42  // If no coefficient was parsed, and there's a +, then it's effectively
43  // parsing an empty string.
44  if (!parsedCoeffResult.has_value()) {
45  return failure();
46  }
47  monomial.exponent = APInt(apintBitWidth, 0);
48  isConstantTerm = true;
49  shouldParseMore = true;
50  return success();
51  }
52 
53  // A monomial can be a trailing constant term, as in `x + 1`.
54  if (failed(parser.parseOptionalKeyword(&variable))) {
55  // If neither a coefficient nor a variable was found, then it's effectively
56  // parsing an empty string.
57  if (!parsedCoeffResult.has_value()) {
58  return failure();
59  }
60 
61  monomial.exponent = APInt(apintBitWidth, 0);
62  isConstantTerm = true;
63  return success();
64  }
65 
66  // Parse exponentiation symbol as `**`. We can't use caret because it's
67  // reserved for basic block identifiers If no star is present, it's treated
68  // as a polynomial with exponent 1.
69  if (succeeded(parser.parseOptionalStar())) {
70  // If there's one * there must be two.
71  if (failed(parser.parseStar())) {
72  return failure();
73  }
74 
75  // If there's a **, then the integer exponent is required.
76  APInt parsedExponent(apintBitWidth, 0);
77  if (failed(parser.parseInteger(parsedExponent))) {
78  parser.emitError(parser.getCurrentLocation(),
79  "found invalid integer exponent");
80  return failure();
81  }
82 
83  monomial.exponent = parsedExponent;
84  } else {
85  monomial.exponent = APInt(apintBitWidth, 1);
86  }
87 
88  if (succeeded(parser.parseOptionalPlus())) {
89  shouldParseMore = true;
90  }
91  return success();
92 }
93 
94 Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
95  if (failed(parser.parseLess()))
96  return {};
97 
99  llvm::StringSet<> variables;
100 
101  while (true) {
102  Monomial parsedMonomial;
103  llvm::StringRef parsedVariableRef;
104  bool isConstantTerm;
105  bool shouldParseMore;
106  if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
107  isConstantTerm, shouldParseMore))) {
108  parser.emitError(parser.getCurrentLocation(), "expected a monomial");
109  return {};
110  }
111 
112  if (!isConstantTerm) {
113  std::string parsedVariable = parsedVariableRef.str();
114  variables.insert(parsedVariable);
115  }
116  monomials.push_back(parsedMonomial);
117 
118  if (shouldParseMore)
119  continue;
120 
121  if (succeeded(parser.parseOptionalGreater())) {
122  break;
123  }
124  parser.emitError(
125  parser.getCurrentLocation(),
126  "expected + and more monomials, or > to end polynomial attribute");
127  return {};
128  }
129 
130  if (variables.size() > 1) {
131  std::string vars = llvm::join(variables.keys(), ", ");
132  parser.emitError(
133  parser.getCurrentLocation(),
134  "polynomials must have one indeterminate, but there were multiple: " +
135  vars);
136  }
137 
138  auto result = Polynomial::fromMonomials(monomials);
139  if (failed(result)) {
140  parser.emitError(parser.getCurrentLocation())
141  << "parsed polynomial must have unique exponents among monomials";
142  return {};
143  }
144  return PolynomialAttr::get(parser.getContext(), result.value());
145 }
146 
147 void RingAttr::print(AsmPrinter &p) const {
148  p << "#polynomial.ring<coefficientType=" << getCoefficientType()
149  << ", coefficientModulus=" << getCoefficientModulus()
150  << ", polynomialModulus=" << getPolynomialModulus() << '>';
151 }
152 
153 Attribute RingAttr::parse(AsmParser &parser, Type type) {
154  if (failed(parser.parseLess()))
155  return {};
156 
157  if (failed(parser.parseKeyword("coefficientType")))
158  return {};
159 
160  if (failed(parser.parseEqual()))
161  return {};
162 
163  Type ty;
164  if (failed(parser.parseType(ty)))
165  return {};
166 
167  if (failed(parser.parseComma()))
168  return {};
169 
170  IntegerAttr coefficientModulusAttr = nullptr;
171  if (succeeded(parser.parseKeyword("coefficientModulus"))) {
172  if (failed(parser.parseEqual()))
173  return {};
174 
175  IntegerType iType = mlir::dyn_cast<IntegerType>(ty);
176  if (!iType) {
177  parser.emitError(parser.getCurrentLocation(),
178  "coefficientType must specify an integer type");
179  return {};
180  }
181  APInt coefficientModulus(iType.getWidth(), 0);
182  auto result = parser.parseInteger(coefficientModulus);
183  if (failed(result)) {
184  parser.emitError(parser.getCurrentLocation(),
185  "invalid coefficient modulus");
186  return {};
187  }
188  coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);
189 
190  if (failed(parser.parseComma()))
191  return {};
192  }
193 
194  PolynomialAttr polyAttr = nullptr;
195  if (succeeded(parser.parseKeyword("polynomialModulus"))) {
196  if (failed(parser.parseEqual()))
197  return {};
198 
199  PolynomialAttr attr;
200  if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
201  return {};
202  polyAttr = attr;
203  }
204 
205  Polynomial poly = polyAttr.getPolynomial();
206  APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
207  IntegerAttr rootAttr = nullptr;
208  if (succeeded(parser.parseOptionalComma())) {
209  if (failed(parser.parseKeyword("primitiveRoot")) ||
210  failed(parser.parseEqual()))
211  return {};
212 
213  ParseResult result = parser.parseInteger(root);
214  if (failed(result)) {
215  parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
216  return {};
217  }
218  rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
219  }
220 
221  if (failed(parser.parseGreater()))
222  return {};
223 
224  return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
225  polyAttr, rootAttr);
226 }
227 
228 } // namespace polynomial
229 } // namespace mlir
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static FailureOr< Polynomial > fromMonomials(ArrayRef< Monomial > monomials)
Definition: Polynomial.cpp:21
@ Type
An inlay hint for a type annotation.
ParseResult parseMonomial(AsmParser &parser, Monomial &monomial, llvm::StringRef &variable, bool &isConstantTerm, bool &shouldParseMore)
Try to parse a monomial.
constexpr unsigned apintBitWidth
This restricts statically defined polynomials to have at most 64-bit coefficients.
Definition: Polynomial.h:28
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72