9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H 10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H 15 #include "llvm/ADT/ScopeExit.h" 70 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 71 assert(!reported &&
"attempting to report a diagnostic more than once");
73 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 74 if (!diagnostics.empty()) {
75 for (
auto &&diagnostic : diagnostics) {
76 diagnostic.getLocation().getContext()->getDiagEngine().emit(
77 std::move(diagnostic));
102 for (
auto &diagnostic : diagnostics) {
103 res.append(diagnostic.str());
113 if (isSilenceableFailure())
114 return "silenceable failure";
115 return "definite failure";
121 if (!diagnostics.empty()) {
130 assert(!diagnostics.empty() &&
"expected a diagnostic to be present");
131 auto guard = llvm::make_scope_exit([&]() { diagnostics.clear(); });
132 return std::move(diagnostics);
137 template <
typename T>
139 assert(isSilenceableFailure() &&
140 "can only append output in silenceable failure state");
141 diagnostics.back() << std::forward<T>(
value);
144 template <
typename T>
146 return std::move(this->
operator<<(std::forward<T>(
value)));
152 assert(isSilenceableFailure() &&
153 "can only attach notes to silenceable failures");
154 return diagnostics.back().attachNote(loc);
160 diagnostics.emplace_back(std::move(diagnostic));
163 : diagnostics(std::move(diagnostics)), result(
failure()) {}
174 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 179 bool reported =
false;
180 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 183 namespace transform {
185 class TransformOpInterface;
197 expensiveChecksEnabled = enable;
205 bool expensiveChecksEnabled =
true;
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 282 "mapping block arguments from a region other than the active one");
283 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 284 return setPayloadOps(argument, operations);
307 state.mappings.erase(region);
308 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 309 state.regionStack.pop_back();
310 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 317 : state(state), region(®ion) {
318 auto res = state.mappings.try_emplace(this->region);
319 assert(res.second &&
"the region scope is already present");
321 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 322 assert(state.regionStack.back()->isProperAncestor(®ion) &&
323 "scope started at a non-nested region");
324 state.regionStack.push_back(®ion);
325 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 396 template <
typename Ty,
typename... Args>
400 "only an class derived from TransformState::Extension is allowed here");
401 auto ptr = std::make_unique<Ty>(*
this, std::forward<Args>(args)...);
402 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
403 assert(result.second &&
"extension already added");
404 return *
static_cast<Ty *
>(result.first->second.get());
408 template <
typename Ty>
412 "only an class derived from TransformState::Extension is allowed here");
413 auto iter = extensions.find(TypeID::get<Ty>());
414 if (iter == extensions.end())
416 return static_cast<Ty *
>(iter->second.get());
420 template <
typename Ty>
424 "only an class derived from TransformState::Extension is allowed here");
425 extensions.erase(TypeID::get<Ty>());
430 static constexpr
Value kTopLevelValue =
Value();
433 const Mappings &getMapping(
Value value)
const {
436 Mappings &getMapping(
Value value) {
438 assert(it != mappings.end() &&
439 "trying to find a mapping for a value from an unmapped region");
444 const Mappings &getMapping(
Operation *operation)
const {
447 Mappings &getMapping(
Operation *operation) {
449 assert(it != mappings.end() &&
450 "trying to find a mapping for an operation from an unmapped region");
470 void removePayloadOps(
Value value);
479 updatePayloadOps(
Value value,
493 void recordHandleInvalidation(
OpOperand &handle);
501 checkAndRecordHandleInvalidation(TransformOpInterface transform);
505 llvm::SmallDenseMap<Region *, Mappings> mappings;
522 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 527 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 589 template <
typename OpTy>
601 return &this->getOperation()->getRegion(region).
front();
609 assert(region.
getParentOp() == this->getOperation() &&
610 "op comes from the wrong region");
612 state, this->getOperation(), region);
616 this->getOperation()->getNumRegions() == 1 &&
617 "must indicate the region to map if the operation has more than one");
618 return mapBlockArguments(state, this->getOperation()->getRegion(0));
642 template <
typename OpTy>
684 StringRef
getName()
override {
return "transform.mapping"; }
695 StringRef
getName()
override {
return "transform.payload_ir"; }
720 template <
typename OpTy>
737 <<
"FunctionalStyleTransformOpTrait should only be attached to ops " 738 "that implement MemoryEffectOpInterface";
747 template <
typename OpTy>
761 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
762 "expected single-operand op");
763 static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
764 "expected single-result op");
766 op->
emitError() <<
"NavigationTransformOpTrait should only be attached " 767 "to ops that implement MemoryEffectOpInterface";
776 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc" 779 namespace transform {
792 template <
typename FnTy>
797 using OpTy =
typename llvm::function_traits<FnTy>::template arg_t<0>;
799 "expected transform function to take an operation");
805 auto specificOp = dyn_cast<OpTy>(target);
808 diag <<
"transform applied to the wrong op kind";
809 diag.
attachNote(target->getLoc()) <<
"when applied to this op";
813 results.back().assign(expectedNumResults,
nullptr);
814 silenceableStack.push_back(std::move(diag));
823 silenceableStack.push_back(std::move(
diag));
825 if (!silenceableStack.empty()) {
827 std::move(silenceableStack));
838 int64_t rows = m.size(), cols = m[0].size();
839 for (int64_t
j = 0;
j < cols; ++
j)
841 for (int64_t i = 0; i < rows; ++i) {
842 assert(static_cast<int64_t>(m[i].size()) == cols);
843 for (int64_t
j = 0;
j < cols; ++
j) {
853 template <
typename OpTy>
857 using TransformOpType =
typename llvm::function_traits<
858 decltype(&OpTy::applyToOne)>::template arg_t<0>;
866 if (targets.empty()) {
868 for (
auto r : this->getOperation()->getResults())
869 transformResults.
set(r.template cast<OpResult>(), empty);
875 int expectedNumResults = this->getOperation()->getNumResults();
878 this->getOperation()->getLoc(), expectedNumResults, targets, results,
880 auto res =
static_cast<OpTy *
>(
this)->applyToOne(specificOp,
881 partialResult, state);
882 if (res.isDefiniteFailure())
887 if (static_cast<int>(partialResult.size()) != expectedNumResults) {
888 auto loc = this->getOperation()->getLoc();
890 << OpTy::getOperationName() <<
" expected to produce " 891 << expectedNumResults <<
" results (actually produced " 892 << partialResult.size() <<
").";
894 <<
"If you need variadic results, consider a generic `apply` " 895 <<
"instead of the specialized `applyToOne`.";
897 <<
"Producing " << expectedNumResults <<
" null results is " 898 <<
"allowed if the use case warrants it.";
899 diag.attachNote(specificOp->getLoc()) <<
"when applied to this op";
904 if (llvm::any_of(partialResult, [](
Operation *op) {
return op; }) &&
905 llvm::any_of(partialResult, [](
Operation *op) {
return !op; })) {
906 auto loc = this->getOperation()->getLoc();
908 << OpTy::getOperationName()
909 <<
" produces both null and non null results.";
910 diag.attachNote(specificOp->getLoc()) <<
"when applied to this op";
921 if (OpTy::template hasTrait<OpTrait::ZeroResults>())
930 if (OpTy::template hasTrait<OpTrait::OneResult>()) {
931 assert(transposedResults.size() == 1 &&
"Expected single result");
932 transformResults.
set(
933 this->getOperation()->getResult(0).
template cast<OpResult>(),
934 transposedResults[0]);
940 for (
const auto &it :
941 llvm::zip(this->getOperation()->getResults(), transposedResults)) {
943 llvm::copy_if(std::get<1>(it), std::back_inserter(filtered),
945 transformResults.
set(std::get<0>(it).
template cast<OpResult>(), filtered);
952 template <
typename OpTy>
955 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
956 "expected single-operand op");
958 return op->
emitError() <<
"TransformEachOpTrait should only be attached to " 959 "ops that implement TransformOpInterface";
965 #endif // DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H Diagnostic & attachNote(Optional< Location > loc=llvm::None)
Attaches a note to the last diagnostic.
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
static std::string diag(llvm::Value &v)
The result of a transform IR operation application.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary...
Operation is a basic unit of execution within MLIR.
This is a value defined by a result of an operation.
Block represents an ordered list of Operations.
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
SmallVector< Diagnostic > && takeDiagnostics()
Take the diagnostic and silence.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic, preserves the other states.
std::string getStatusString() const
Returns a string representation of the failure mode (for error reporting).
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Region * getParentRegion()
Return the Region in which this Value is defined.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
DiagnosedSilenceableFailure && operator<<(T &&value) &&
DiagnosedSilenceableFailure(LogicalResult result)
DiagnosedSilenceableFailure & operator<<(T &&value) &
Streams the given values into the last diagnostic.
This base class is used for derived effects that are non-parametric.
bool succeeded() const
Returns true if this is a success.
This class represents an argument of a Block.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.
static DiagnosedSilenceableFailure silenceableFailure(SmallVector< Diagnostic > &&diag)
Operation * getParentOp()
Return the parent operation this region is attached to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static llvm::ManagedStatic< PassManagerOptions > options
Helper class for implementing traits.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class represents an operand of an operation.
LogicalResult verifyTrait(ConcreteOp op)
This function defines the internal implementation of the verifyTrait method on FunctionOpInterface::T...
Region * getParentRegion()
Returns the region to which the instruction belongs.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
OperationName getName()
The name of an operation is the key identifier for it.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class provides an abstraction over the different types of ranges over Values.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
std::string getMessage() const
Returns the diagnostic message without emitting it.