MLIR 22.0.0git
SMTAttributes.cpp
Go to the documentation of this file.
1//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Builders.h"
14#include "llvm/ADT/TypeSwitch.h"
15
16using namespace mlir;
17using namespace mlir::smt;
18
19//===----------------------------------------------------------------------===//
20// BitVectorAttr
21//===----------------------------------------------------------------------===//
22
23LogicalResult BitVectorAttr::verify(
25 APInt value) { // NOLINT(performance-unnecessary-value-param)
26 if (value.getBitWidth() < 1)
27 return emitError() << "bit-width must be at least 1, but got "
28 << value.getBitWidth();
29 return success();
30}
31
32std::string BitVectorAttr::getValueAsString(bool prefix) const {
33 unsigned width = getValue().getBitWidth();
34 SmallVector<char> toPrint;
35 StringRef pref = prefix ? "#" : "";
36 if (width % 4 == 0) {
37 getValue().toString(toPrint, 16, false, false, false);
38 // APInt's 'toString' omits leading zeros. However, those are critical here
39 // because they determine the bit-width of the bit-vector.
40 SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
41 return (pref + "x" + Twine(leadingZeros) + toPrint).str();
42 }
43
44 getValue().toString(toPrint, 2, false, false, false);
45 // APInt's 'toString' omits leading zeros
46 SmallVector<char> leadingZeros(width - toPrint.size(), '0');
47 return (pref + "b" + Twine(leadingZeros) + toPrint).str();
48}
49
50/// Parse an SMT-LIB formatted bit-vector string.
51static FailureOr<APInt>
53 StringRef value) {
54 if (value[0] != '#')
55 return emitError() << "expected '#'";
56
57 if (value.size() < 3)
58 return emitError() << "expected at least one digit";
59
60 if (value[1] == 'b')
61 return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()),
62 2);
63
64 if (value[1] == 'x')
65 return APInt((value.size() - 2) * 4,
66 std::string(value.begin() + 2, value.end()), 16);
67
68 return emitError() << "expected either 'b' or 'x'";
69}
70
71BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
72 auto maybeValue = parseBitVectorString(nullptr, value);
73
74 assert(succeeded(maybeValue) && "string must have SMT-LIB format");
75 return Base::get(context, *maybeValue);
76}
77
78BitVectorAttr
79BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
80 MLIRContext *context, StringRef value) {
81 auto maybeValue = parseBitVectorString(emitError, value);
82 if (failed(maybeValue))
83 return {};
84
85 return Base::getChecked(emitError, context, *maybeValue);
86}
87
88BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value,
89 unsigned width) {
90 return Base::get(context, APInt(width, value));
91}
92
93BitVectorAttr
94BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
95 MLIRContext *context, uint64_t value,
96 unsigned width) {
97 if (width < 64 && value >= (UINT64_C(1) << width)) {
98 emitError() << "value does not fit in a bit-vector of desired width";
99 return {};
100 }
101 return Base::getChecked(emitError, context, APInt(width, value));
102}
103
104Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
105 llvm::SMLoc loc = odsParser.getCurrentLocation();
106
107 APInt val;
108 if (odsParser.parseLess() || odsParser.parseInteger(val) ||
109 odsParser.parseGreater())
110 return {};
111
112 // Requires the use of `quantified(<attr>)` in operation assembly formats.
113 if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
114 odsParser.emitError(loc) << "explicit bit-vector type required";
115 return {};
116 }
117
118 unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
119
120 if (width > val.getBitWidth()) {
121 // sext is always safe here, even for unsigned values, because the
122 // parseOptionalInteger method will return something with a zero in the
123 // top bits if it is a positive number.
124 val = val.sext(width);
125 } else if (width < val.getBitWidth()) {
126 // The parser can return an unnecessarily wide result.
127 // This isn't a problem, but truncating off bits is bad.
128 unsigned neededBits =
129 val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
130 if (width < neededBits) {
131 odsParser.emitError(loc)
132 << "integer value out of range for given bit-vector type " << odsType;
133 return {};
134 }
135 val = val.trunc(width);
136 }
137
138 return BitVectorAttr::get(odsParser.getContext(), val);
139}
140
141void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
142 // This printer only works for the extended format where the MLIR
143 // infrastructure prints the type for us. This means, the attribute should
144 // never be used without `quantified` in an assembly format.
145 odsPrinter << "<" << getValue() << ">";
146}
147
148Type BitVectorAttr::getType() const {
149 return BitVectorType::get(getContext(), getValue().getBitWidth());
150}
151
152//===----------------------------------------------------------------------===//
153// ODS Boilerplate
154//===----------------------------------------------------------------------===//
155
156#define GET_ATTRDEF_CLASSES
157#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
158
/// Register all ODS-generated SMT attribute classes with the dialect.
void SMTDialect::registerAttributes() {
  // With GET_ATTRDEF_LIST defined, the generated .inc file is included inside
  // the template argument list, so it presumably expands to the list of
  // generated attribute classes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
      >();
}
return success()
static unsigned getBitWidth(Type type)
Definition Pattern.cpp:378
b getContext())
static FailureOr< APInt > parseBitVectorString(function_ref< InFlightDiagnostic()> emitError, StringRef value)
Parse an SMT-LIB formatted bit-vector string.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
This base class exposes generic asm printer hooks, usable across the various derived printers.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152