MLIR  21.0.0git
SMTAttributes.cpp
Go to the documentation of this file.
1 //===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 #include "llvm/Support/Format.h"
16 
17 using namespace mlir;
18 using namespace mlir::smt;
19 
20 //===----------------------------------------------------------------------===//
21 // BitVectorAttr
22 //===----------------------------------------------------------------------===//
23 
24 LogicalResult BitVectorAttr::verify(
26  APInt value) { // NOLINT(performance-unnecessary-value-param)
27  if (value.getBitWidth() < 1)
28  return emitError() << "bit-width must be at least 1, but got "
29  << value.getBitWidth();
30  return success();
31 }
32 
33 std::string BitVectorAttr::getValueAsString(bool prefix) const {
34  unsigned width = getValue().getBitWidth();
35  SmallVector<char> toPrint;
36  StringRef pref = prefix ? "#" : "";
37  if (width % 4 == 0) {
38  getValue().toString(toPrint, 16, false, false, false);
39  // APInt's 'toString' omits leading zeros. However, those are critical here
40  // because they determine the bit-width of the bit-vector.
41  SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
42  return (pref + "x" + Twine(leadingZeros) + toPrint).str();
43  }
44 
45  getValue().toString(toPrint, 2, false, false, false);
46  // APInt's 'toString' omits leading zeros
47  SmallVector<char> leadingZeros(width - toPrint.size(), '0');
48  return (pref + "b" + Twine(leadingZeros) + toPrint).str();
49 }
50 
51 /// Parse an SMT-LIB formatted bit-vector string.
52 static FailureOr<APInt>
54  StringRef value) {
55  if (value[0] != '#')
56  return emitError() << "expected '#'";
57 
58  if (value.size() < 3)
59  return emitError() << "expected at least one digit";
60 
61  if (value[1] == 'b')
62  return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()),
63  2);
64 
65  if (value[1] == 'x')
66  return APInt((value.size() - 2) * 4,
67  std::string(value.begin() + 2, value.end()), 16);
68 
69  return emitError() << "expected either 'b' or 'x'";
70 }
71 
72 BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
73  auto maybeValue = parseBitVectorString(nullptr, value);
74 
75  assert(succeeded(maybeValue) && "string must have SMT-LIB format");
76  return Base::get(context, *maybeValue);
77 }
78 
79 BitVectorAttr
80 BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
81  MLIRContext *context, StringRef value) {
82  auto maybeValue = parseBitVectorString(emitError, value);
83  if (failed(maybeValue))
84  return {};
85 
86  return Base::getChecked(emitError, context, *maybeValue);
87 }
88 
89 BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value,
90  unsigned width) {
91  return Base::get(context, APInt(width, value));
92 }
93 
94 BitVectorAttr
95 BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
96  MLIRContext *context, uint64_t value,
97  unsigned width) {
98  if (width < 64 && value >= (UINT64_C(1) << width)) {
99  emitError() << "value does not fit in a bit-vector of desired width";
100  return {};
101  }
102  return Base::getChecked(emitError, context, APInt(width, value));
103 }
104 
105 Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
106  llvm::SMLoc loc = odsParser.getCurrentLocation();
107 
108  APInt val;
109  if (odsParser.parseLess() || odsParser.parseInteger(val) ||
110  odsParser.parseGreater())
111  return {};
112 
113  // Requires the use of `quantified(<attr>)` in operation assembly formats.
114  if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
115  odsParser.emitError(loc) << "explicit bit-vector type required";
116  return {};
117  }
118 
119  unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
120 
121  if (width > val.getBitWidth()) {
122  // sext is always safe here, even for unsigned values, because the
123  // parseOptionalInteger method will return something with a zero in the
124  // top bits if it is a positive number.
125  val = val.sext(width);
126  } else if (width < val.getBitWidth()) {
127  // The parser can return an unnecessarily wide result.
128  // This isn't a problem, but truncating off bits is bad.
129  unsigned neededBits =
130  val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
131  if (width < neededBits) {
132  odsParser.emitError(loc)
133  << "integer value out of range for given bit-vector type " << odsType;
134  return {};
135  }
136  val = val.trunc(width);
137  }
138 
139  return BitVectorAttr::get(odsParser.getContext(), val);
140 }
141 
142 void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
143  // This printer only works for the extended format where the MLIR
144  // infrastructure prints the type for us. This means, the attribute should
145  // never be used without `quantified` in an assembly format.
146  odsPrinter << "<" << getValue() << ">";
147 }
148 
150  return BitVectorType::get(getContext(), getValue().getBitWidth());
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // ODS Boilerplate
155 //===----------------------------------------------------------------------===//
156 
157 #define GET_ATTRDEF_CLASSES
158 #include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
159 
160 void SMTDialect::registerAttributes() {
161  addAttributes<
162 #define GET_ATTRDEF_LIST
163 #include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
164  >();
165 }
static StringRef getValueAsString(const Init *init)
Definition: Attribute.cpp:28
static unsigned getBitWidth(Type type)
Definition: Pattern.cpp:385
static MLIRContext * getContext(OpFoldResult val)
static FailureOr< APInt > parseBitVectorString(function_ref< InFlightDiagnostic()> emitError, StringRef value)
Parse an SMT-LIB formatted bit-vector string.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseGreater()=0
Parse a '>' token.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:424