MLIR  22.0.0git
SMTExtensionOps.cpp
Go to the documentation of this file.
1 //===- SMTExtensionOps.cpp - SMT extension for the Transform dialect ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 
14 using namespace mlir;
15 
16 #define GET_OP_CLASSES
17 #include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
18 
19 //===----------------------------------------------------------------------===//
20 // ConstrainParamsOp
21 //===----------------------------------------------------------------------===//
22 
// Reports this op's effects into `effects`: the operands obtained from
// getParamsMutable() are registered through onlyReadsHandle, and the op's
// results are registered through producesHandle.
23 void transform::smt::ConstrainParamsOp::getEffects(
25  onlyReadsHandle(getParamsMutable(), effects);
26  producesHandle(getResults(), effects);
27 }
28 
// Applies the op. No solver-backed semantics are implemented yet, so the body
// does nothing except return a silenceable failure located at this op.
30 transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
33  // TODO: Proper operational semantics are to check the SMT problem in the body
34  // with an SMT solver with the arguments of the body constrained to the
35  // values passed into the op. Success or failure is then determined by
36  // the solver's result.
37  // One way to support this is to just promise the TransformOpInterface
38  // and allow for users to attach their own implementation, which would,
39  // e.g., translate the ops to SMTLIB and hand that over to the user's
40  // favourite solver. This requires changes to the dialect's verifier.
41  return emitSilenceableFailure(getLoc())
42  << "op does not have interpreted semantics yet";
43 }
44 
// Structural checks on the op and its body, performed in this order:
//  1. the last op of the body's first block is an mlir::smt::YieldOp;
//  2. the op has as many operands as the body has block arguments, and each
//  (block argument, operand) pair has compatible types;
//  3. every op directly inside the body belongs to the SMT dialect;
//  4. the yield has as many operands as the op has results, and each
//  (yield operand, result) pair has compatible types.
// The first violated check emits an op error and ends verification.
//
// NOTE(review): `.back()` is called on the first block without a visible
// emptiness check - confirm a non-empty body is guaranteed before this runs.
46  auto yieldTerminator =
47  dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
48  if (!yieldTerminator)
49  return emitOpError() << "expected '"
50  << mlir::smt::YieldOp::getOperationName()
51  << "' as terminator";
52 
// Checks one SMT-side type against its transform-param counterpart.
//  - idx / smtDesc / paramDesc are only used to word the diagnostic.
//  - atOp is the op on which a mismatch is reported.
// Accepted pairings:
//  - any of the three SMT types with AnyParamType (no further checking);
//  - !smt.int with a param wrapping any IntegerType;
//  - !smt.bool with a param wrapping an IntegerType of width 1;
//  - an SMT bit-vector with a param wrapping an IntegerType of equal width.
// Returns the diagnostic emitted on `atOp` for a mismatch, or a
// default-constructed one when the pair is acceptable; callers convert the
// result to LogicalResult to tell the two apart.
53  auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
54  Type paramType, StringRef paramDesc,
55  auto *atOp) -> InFlightDiagnostic {
56  if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
57  smtType))
58  return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
59  << " is expected to be either a !smt.bool, a "
60  "!smt.int, or a !smt.bv";
61 
62  assert(isa<TransformParamTypeInterface>(paramType) &&
63  "ODS specifies params' type should implement param interface");
64  if (isa<transform::AnyParamType>(paramType))
65  return {}; // No further checks can be done.
66 
67  // NB: This cast must succeed as long as the only implementors of
68  // TransformParamTypeInterface are AnyParamType and ParamType.
69  Type typeWrappedByParam = cast<ParamType>(paramType).getType();
70 
71  if (isa<mlir::smt::IntType>(smtType)) {
72  if (!isa<IntegerType>(typeWrappedByParam))
73  return atOp->emitOpError()
74  << "the type of " << smtDesc << " #" << idx
75  << " is !smt.int though the corresponding " << paramDesc
76  << " type (" << paramType << ") is not wrapping an integer type";
77  } else if (isa<mlir::smt::BoolType>(smtType)) {
78  auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
79  if (!wrappedIntType || wrappedIntType.getWidth() != 1)
80  return atOp->emitOpError()
81  << "the type of " << smtDesc << " #" << idx
82  << " is !smt.bool though the corresponding " << paramDesc
83  << " type (" << paramType << ") is not wrapping i1";
84  } else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
85  auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
86  if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
87  return atOp->emitOpError()
88  << "the type of " << smtDesc << " #" << idx << " is " << smtType
89  << " though the corresponding " << paramDesc << " type ("
90  << paramType
91  << ") is not wrapping an integer type of the same bitwidth";
92  }
93 
94  return {};
95  };
96 
// The count check comes first, so the paired iteration below only ever walks
// operand types and block-argument types of equal length.
97  if (getOperands().size() != getBody().getNumArguments())
98  return emitOpError(
99  "must have the same number of block arguments as operands");
100 
// Block arguments are the SMT side; operands are the param side. Mismatches
// are reported on this op.
101  for (auto [idx, operandType, blockArgType] :
102  llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
103  InFlightDiagnostic typeCheckResult =
104  checkTypes(idx, blockArgType, "block arg", operandType, "operand",
105  /*atOp=*/this);
106  if (LogicalResult(typeCheckResult).failed())
107  return typeCheckResult;
108  }
109 
// Only ops from the SMT dialect are accepted among the ops iterated here.
110  for (auto &op : getBody().getOps()) {
111  if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
112  return emitOpError(
113  "ops contained in region should belong to SMT-dialect");
114  }
115 
// A count mismatch between yield operands and results is reported on the
// yield op rather than on this op.
116  if (yieldTerminator->getNumOperands() != getNumResults())
117  return yieldTerminator.emitOpError()
118  << "expected terminator to have as many operands as the parent op "
119  "has results";
120 
// Yield operands are the SMT side; results are the param side. Mismatches are
// reported on the yield op.
// NOTE(review): `checkTypes` accepts AnyParamType, but every result type is
// passed through cast<transform::ParamType> before `checkTypes` sees it. If a
// result may be AnyParamType, that cast would fail before the type check runs
// - confirm against the op's ODS result constraint.
121  for (auto [idx, termOperandType, resultType] : llvm::enumerate(
122  yieldTerminator->getOperands().getType(), getResultTypes())) {
123  InFlightDiagnostic typeCheckResult =
124  checkTypes(idx, termOperandType, "terminator operand",
125  cast<transform::ParamType>(resultType), "result",
126  /*atOp=*/&yieldTerminator);
127  if (LogicalResult(typeCheckResult).failed())
128  return typeCheckResult;
129  }
130 
131  return success();
132 }
The result of a transform IR operation application.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:316
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Include the generated interface declarations.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423