12 #include "llvm/ADT/APSInt.h"
22 LogicalResult BVConstantOp::inferReturnTypes(
27 inferredReturnTypes.push_back(
28 properties.
as<Properties *>()->getValue().getType());
32 void BVConstantOp::getAsmResultNames(
35 llvm::raw_svector_ostream specialName(specialNameBuffer);
36 specialName <<
"c" << getValue().getValue() <<
"_bv"
37 << getValue().getValue().getBitWidth();
38 setNameFn(getResult(), specialName.str());
42 assert(adaptor.getOperands().empty() &&
"constant has no operands");
43 return getValueAttr();
50 void DeclareFunOp::getAsmResultNames(
52 setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() :
"");
59 LogicalResult SolverOp::verifyRegions() {
60 if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes())
61 return emitOpError() <<
"types of yielded values must match return values";
62 if (getBody()->getArgumentTypes() != getInputs().getTypes())
64 <<
"block argument types must match the types of the 'inputs'";
73 LogicalResult CheckOp::verifyRegions() {
74 if (getSatRegion().front().getTerminator()->getOperands().getTypes() !=
76 return emitOpError() <<
"types of yielded values in 'sat' region must "
77 "match return values";
78 if (getUnknownRegion().front().getTerminator()->getOperands().getTypes() !=
80 return emitOpError() <<
"types of yielded values in 'unknown' region must "
81 "match return values";
82 if (getUnsatRegion().front().getTerminator()->getOperands().getTypes() !=
84 return emitOpError() <<
"types of yielded values in 'unsat' region must "
85 "match return values";
119 printer <<
' ' << getInputs();
121 printer <<
" : " << getInputs().front().getType();
125 if (getInputs().size() < 2)
126 return emitOpError() <<
"'inputs' must have at least size 2, but got "
127 << getInputs().size();
141 printer <<
' ' << getInputs();
143 printer <<
" : " << getInputs().front().getType();
147 if (getInputs().size() < 2)
148 return emitOpError() <<
"'inputs' must have at least size 2, but got "
149 << getInputs().size();
159 unsigned rangeWidth =
getType().getWidth();
160 unsigned inputWidth = cast<BitVectorType>(getInput().
getType()).getWidth();
161 if (getLowBit() + rangeWidth > inputWidth)
162 return emitOpError(
"range to be extracted is too big, expected range "
163 "starting at index ")
164 << getLowBit() <<
" of length " << rangeWidth
165 <<
" requires input width of at least " << (getLowBit() + rangeWidth)
166 <<
", but the input width is only " << inputWidth;
174 LogicalResult ConcatOp::inferReturnTypes(
179 context, cast<BitVectorType>(operands[0].
getType()).getWidth() +
180 cast<BitVectorType>(operands[1].
getType()).getWidth()));
189 unsigned inputWidth = cast<BitVectorType>(getInput().
getType()).getWidth();
190 unsigned resultWidth =
getType().getWidth();
191 if (resultWidth % inputWidth != 0)
192 return emitOpError() <<
"result bit-vector width must be a multiple of the "
193 "input bit-vector width";
198 unsigned RepeatOp::getCount() {
199 unsigned inputWidth = cast<BitVectorType>(getInput().
getType()).getWidth();
200 unsigned resultWidth =
getType().getWidth();
201 return resultWidth / inputWidth;
206 unsigned inputWidth = cast<BitVectorType>(input.
getType()).getWidth();
208 build(builder, state, resultTy, input);
220 if (count.isNonPositive())
221 return parser.
emitError(countLoc) <<
"integer must be positive";
232 auto bvInputTy = dyn_cast<BitVectorType>(inputType);
234 return parser.
emitError(inputLoc) <<
"input must have bit-vector type";
238 const unsigned maxBw = 63;
239 if (count.getActiveBits() > maxBw)
241 <<
"integer must fit into " << maxBw <<
" bits";
246 APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw);
247 if (resultBw.getActiveBits() > maxBw)
249 <<
"result bit-width (provided integer times bit-width of the input "
250 "type) must fit into "
260 printer <<
" " << getCount() <<
" times " << getInput();
262 printer <<
" : " << getInput().getType();
269 void BoolConstantOp::getAsmResultNames(
271 setNameFn(getResult(), getValue() ?
"true" :
"false");
274 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
275 assert(adaptor.getOperands().empty() &&
"constant has no operands");
276 return getValueAttr();
283 void IntConstantOp::getAsmResultNames(
286 llvm::raw_svector_ostream specialName(specialNameBuffer);
287 specialName <<
"c" << getValue();
288 setNameFn(getResult(), specialName.str());
292 assert(adaptor.getOperands().empty() &&
"constant has no operands");
293 return getValueAttr();
297 p <<
" " << getValue();
320 template <
typename QuantifierOp>
322 if (op.getBoundVarNames() &&
323 op.getBody().getNumArguments() != op.getBoundVarNames()->size())
324 return op.emitOpError(
325 "number of bound variable names must match number of block arguments");
327 return op.emitOpError()
328 <<
"bound variables must by any non-function SMT value";
330 if (op.getBody().front().getTerminator()->getNumOperands() != 1)
331 return op.emitOpError(
"must have exactly one yielded value");
333 op.getBody().front().getTerminator()->getOperand(0).getType()))
334 return op.emitOpError(
"yielded value must be of '!smt.bool' type");
337 unsigned i = regionWithIndex.index();
338 Region ®ion = regionWithIndex.value();
341 return op.emitOpError()
342 <<
"block argument number and types of the 'body' "
343 "and 'patterns' region #"
344 << i <<
" must match";
346 return op.emitOpError() <<
"'patterns' region #" << i
347 <<
" must have at least one yielded value";
351 if (!isa<SMTDialect>(childOp->
getDialect())) {
352 auto diag = op.emitOpError()
353 <<
"the 'patterns' region #" << i
354 <<
" may only contain SMT dialect operations";
355 diag.attachNote(childOp->getLoc()) <<
"first non-SMT operation here";
356 return WalkResult::interrupt();
361 if (isa<ForallOp, ExistsOp>(childOp)) {
362 auto diag = op.emitOpError() <<
"the 'patterns' region #" << i
363 <<
" must not contain "
364 "any variable binding operations";
365 diag.attachNote(childOp->
getLoc()) <<
"first violating operation here";
371 if (result.wasInterrupted())
378 template <
typename Properties>
384 uint32_t weight,
bool noPattern) {
392 if (boundVarNames.has_value()) {
394 for (StringRef str : *boundVarNames)
395 boundVarNamesList.emplace_back(odsBuilder.
getStringAttr(str));
410 if (patternBuilder) {
424 if (!getPatterns().empty() && getNoPattern())
425 return emitOpError() <<
"patterns and the no_pattern attribute must not be "
426 "specified at the same time";
431 LogicalResult ForallOp::verifyRegions() {
435 void ForallOp::build(
440 uint32_t weight,
bool noPattern) {
441 buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
442 boundVarNames, patternBuilder, weight, noPattern);
450 if (!getPatterns().empty() && getNoPattern())
451 return emitOpError() <<
"patterns and the no_pattern attribute must not be "
452 "specified at the same time";
457 LogicalResult ExistsOp::verifyRegions() {
461 void ExistsOp::build(
466 uint32_t weight,
bool noPattern) {
467 buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder,
468 boundVarNames, patternBuilder, weight, noPattern);
471 #define GET_OP_CLASSES
472 #include "mlir/Dialect/SMT/IR/SMT.cpp.inc"
static std::string diag(const llvm::Value &value)
static void buildQuantifier(OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, function_ref< Value(OpBuilder &, Location, ValueRange)> bodyBuilder, std::optional< ArrayRef< StringRef >> boundVarNames, function_ref< ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, uint32_t weight, bool noPattern)
static LogicalResult verifyQuantifierRegions(QuantifierOp op)
static LogicalResult parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, OperationState &result)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Block represents an ordered list of Operations.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
This class provides an abstraction over the different types of ranges over Regions.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Returns the argument types of the first block within the region.
RetT walk(FnT &&callback)
Walk all nested operations, blocks or regions (including this region), depending on the type of callb...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
bool isAnyNonFuncSMTValueType(mlir::Type type)
Returns whether the given type is an SMT value type (excluding functions).
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.