MLIR  22.0.0git
SMTAttributes.cpp
Go to the documentation of this file.
1 //===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/IR/Builders.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::smt;
18 
19 //===----------------------------------------------------------------------===//
20 // BitVectorAttr
21 //===----------------------------------------------------------------------===//
22 
23 LogicalResult BitVectorAttr::verify(
25  APInt value) { // NOLINT(performance-unnecessary-value-param)
26  if (value.getBitWidth() < 1)
27  return emitError() << "bit-width must be at least 1, but got "
28  << value.getBitWidth();
29  return success();
30 }
31 
32 std::string BitVectorAttr::getValueAsString(bool prefix) const {
33  unsigned width = getValue().getBitWidth();
34  SmallVector<char> toPrint;
35  StringRef pref = prefix ? "#" : "";
36  if (width % 4 == 0) {
37  getValue().toString(toPrint, 16, false, false, false);
38  // APInt's 'toString' omits leading zeros. However, those are critical here
39  // because they determine the bit-width of the bit-vector.
40  SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
41  return (pref + "x" + Twine(leadingZeros) + toPrint).str();
42  }
43 
44  getValue().toString(toPrint, 2, false, false, false);
45  // APInt's 'toString' omits leading zeros
46  SmallVector<char> leadingZeros(width - toPrint.size(), '0');
47  return (pref + "b" + Twine(leadingZeros) + toPrint).str();
48 }
49 
50 /// Parse an SMT-LIB formatted bit-vector string.
51 static FailureOr<APInt>
53  StringRef value) {
54  if (value[0] != '#')
55  return emitError() << "expected '#'";
56 
57  if (value.size() < 3)
58  return emitError() << "expected at least one digit";
59 
60  if (value[1] == 'b')
61  return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()),
62  2);
63 
64  if (value[1] == 'x')
65  return APInt((value.size() - 2) * 4,
66  std::string(value.begin() + 2, value.end()), 16);
67 
68  return emitError() << "expected either 'b' or 'x'";
69 }
70 
71 BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
72  auto maybeValue = parseBitVectorString(nullptr, value);
73 
74  assert(succeeded(maybeValue) && "string must have SMT-LIB format");
75  return Base::get(context, *maybeValue);
76 }
77 
78 BitVectorAttr
79 BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
80  MLIRContext *context, StringRef value) {
81  auto maybeValue = parseBitVectorString(emitError, value);
82  if (failed(maybeValue))
83  return {};
84 
85  return Base::getChecked(emitError, context, *maybeValue);
86 }
87 
88 BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value,
89  unsigned width) {
90  return Base::get(context, APInt(width, value));
91 }
92 
93 BitVectorAttr
94 BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
95  MLIRContext *context, uint64_t value,
96  unsigned width) {
97  if (width < 64 && value >= (UINT64_C(1) << width)) {
98  emitError() << "value does not fit in a bit-vector of desired width";
99  return {};
100  }
101  return Base::getChecked(emitError, context, APInt(width, value));
102 }
103 
104 Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
105  llvm::SMLoc loc = odsParser.getCurrentLocation();
106 
107  APInt val;
108  if (odsParser.parseLess() || odsParser.parseInteger(val) ||
109  odsParser.parseGreater())
110  return {};
111 
112  // Requires the use of `quantified(<attr>)` in operation assembly formats.
113  if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
114  odsParser.emitError(loc) << "explicit bit-vector type required";
115  return {};
116  }
117 
118  unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
119 
120  if (width > val.getBitWidth()) {
121  // sext is always safe here, even for unsigned values, because the
122  // parseOptionalInteger method will return something with a zero in the
123  // top bits if it is a positive number.
124  val = val.sext(width);
125  } else if (width < val.getBitWidth()) {
126  // The parser can return an unnecessarily wide result.
127  // This isn't a problem, but truncating off bits is bad.
128  unsigned neededBits =
129  val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
130  if (width < neededBits) {
131  odsParser.emitError(loc)
132  << "integer value out of range for given bit-vector type " << odsType;
133  return {};
134  }
135  val = val.trunc(width);
136  }
137 
138  return BitVectorAttr::get(odsParser.getContext(), val);
139 }
140 
141 void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
142  // This printer only works for the extended format where the MLIR
143  // infrastructure prints the type for us. This means, the attribute should
144  // never be used without `quantified` in an assembly format.
145  odsPrinter << "<" << getValue() << ">";
146 }
147 
149  return BitVectorType::get(getContext(), getValue().getBitWidth());
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // ODS Boilerplate
154 //===----------------------------------------------------------------------===//
155 
156 #define GET_ATTRDEF_CLASSES
157 #include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
158 
159 void SMTDialect::registerAttributes() {
160  addAttributes<
161 #define GET_ATTRDEF_LIST
162 #include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
163  >();
164 }
static StringRef getValueAsString(const Init *init)
Definition: Attribute.cpp:27
static unsigned getBitWidth(Type type)
Definition: Pattern.cpp:385
static MLIRContext * getContext(OpFoldResult val)
static FailureOr< APInt > parseBitVectorString(function_ref< InFlightDiagnostic()> emitError, StringRef value)
Parse an SMT-LIB formatted bit-vector string.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseGreater()=0
Parse a '>' token.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423