9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
19 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
24 class TransformOpInterface;
25 class TransformResults;
26 class TransformRewriter;
88 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
105 expensiveChecksEnabled = enable;
111 enforceSingleToplevelTransformOp = enable;
120 return enforceSingleToplevelTransformOp;
124 bool expensiveChecksEnabled =
true;
125 bool enforceSingleToplevelTransformOp =
true;
135 Operation *payloadRoot, TransformOpInterface transform,
137 const TransformOptions &
options = TransformOptions(),
138 bool enforceToplevelTransformOp =
true,
139 function_ref<void(TransformState &)> stateInitializer =
nullptr,
140 function_ref<LogicalResult(TransformState &)> stateExporter =
nullptr);
199 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
214 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
215 TransformIRTimestampMapping timestamps;
216 void incrementTimestamp(
Value value) { ++timestamps[value]; }
241 return topLevelMappedValues[position];
251 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
256 int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
261 return llvm::make_filter_range(view, [=](
Operation *op) {
262 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
263 [[maybe_unused]]
bool sameTimestamp =
264 currentTimestamp == this->getMapping(value).timestamps.lookup(value);
265 assert(sameTimestamp &&
"iterator was invalidated during iteration");
267 return op !=
nullptr;
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
285 int64_t currentTimestamp =
286 getMapping(handleValue).timestamps.lookup(handleValue);
287 return llvm::make_filter_range(view, [=](
Value v) {
288 [[maybe_unused]]
bool sameTimestamp =
290 this->getMapping(handleValue).timestamps.lookup(handleValue);
291 assert(sameTimestamp &&
"iterator was invalidated during iteration");
295 return llvm::make_range(view.begin(), view.end());
305 bool includeOutOfScope =
false)
const;
313 bool includeOutOfScope =
false)
const;
329 "mapping block arguments from a region other than the active one");
330 return setPayloadOps(argument, operations);
363 : state(state), region(®ion) {
364 auto res = state.mappings.insert(
365 std::make_pair(®ion, std::make_unique<Mappings>()));
366 assert(res.second &&
"the region scope is already present");
368 state.regionStack.push_back(
this);
378 TransformOpInterface currentTransform;
451 template <
typename Ty,
typename... Args>
454 std::is_base_of<Extension, Ty>::value,
455 "only an class derived from TransformState::Extension is allowed here");
456 auto ptr = std::make_unique<Ty>(*
this, std::forward<Args>(args)...);
457 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
458 assert(result.second &&
"extension already added");
459 return *
static_cast<Ty *
>(result.first->second.get());
463 template <
typename Ty>
466 std::is_base_of<Extension, Ty>::value,
467 "only an class derived from TransformState::Extension is allowed here");
468 auto iter = extensions.find(TypeID::get<Ty>());
469 if (iter == extensions.end())
471 return static_cast<Ty *
>(iter->second.get());
475 template <
typename Ty>
478 std::is_base_of<Extension, Ty>::value,
479 "only an class derived from TransformState::Extension is allowed here");
480 extensions.erase(TypeID::get<Ty>());
485 static constexpr
Value kTopLevelValue =
Value();
493 const TransformOptions &
options = TransformOptions());
498 const Mappings &getMapping(
Value value,
bool allowOutOfScope =
false)
const {
499 return const_cast<TransformState *
>(
this)->getMapping(value,
502 Mappings &getMapping(Value value,
bool allowOutOfScope =
false) {
504 auto it = mappings.find(region);
505 assert(it != mappings.end() &&
506 "trying to find a mapping for a value from an unmapped region");
508 if (!allowOutOfScope) {
509 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
512 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
513 llvm_unreachable(
"trying to get mapping beyond region that is "
514 "isolated from above");
524 const Mappings &getMapping(Operation *operation,
525 bool allowOutOfScope =
false)
const {
526 return const_cast<TransformState *
>(
this)->getMapping(operation,
529 Mappings &getMapping(Operation *operation,
bool allowOutOfScope =
false) {
530 Region *region = operation->getParentRegion();
531 auto it = mappings.find(region);
532 assert(it != mappings.end() &&
533 "trying to find a mapping for an operation from an unmapped region");
535 if (!allowOutOfScope) {
536 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
539 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
540 llvm_unreachable(
"trying to get mapping beyond region that is "
541 "isolated from above");
550 LogicalResult updateStateFromResults(
const TransformResults &results,
551 ResultRange opResults);
556 ArrayRef<Operation *> getPayloadOpsView(Value value)
const;
560 ArrayRef<Value> getPayloadValuesView(Value handleValue)
const;
597 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
638 LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
644 LogicalResult setParams(Value value, ArrayRef<Param> params);
652 void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
653 bool allowOutOfScope =
false);
655 void forgetValueMapping(Value valueHandle,
656 ArrayRef<Operation *> payloadOperations);
664 LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
669 LogicalResult replacePayloadValue(Value value, Value replacement);
687 void recordOpHandleInvalidation(OpOperand &consumingHandle,
688 ArrayRef<Operation *> potentialAncestors,
690 InvalidatedHandleMap &newlyInvalidated)
const;
709 void recordOpHandleInvalidationOne(
710 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
711 Operation *payloadOp, Value otherHandle, Value throughValue,
712 InvalidatedHandleMap &newlyInvalidated)
const;
729 void recordValueHandleInvalidationByOpHandleOne(
730 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
731 Value payloadValue, Value valueHandle,
732 InvalidatedHandleMap &newlyInvalidated)
const;
748 recordValueHandleInvalidation(OpOperand &valueHandle,
749 InvalidatedHandleMap &newlyInvalidated)
const;
757 checkAndRecordHandleInvalidation(TransformOpInterface transform);
762 LogicalResult checkAndRecordHandleInvalidationImpl(
763 transform::TransformOpInterface transform,
764 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const;
767 void compactOpHandles();
774 llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
780 DenseSet<Value> opHandlesToCompact;
784 DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
792 RaggedArray<MappedValue> topLevelMappedValues;
801 InvalidatedHandleMap invalidatedHandles;
806 SmallVector<RegionScope *> regionStack;
810 std::unique_ptr<RegionScope> topLevelRegionScope;
823 template <
typename Range>
826 assert(position <
static_cast<int64_t
>(operations.size()) &&
827 "setting results for a non-existent handle");
828 assert(operations[position].data() ==
nullptr &&
"results already set");
829 assert(params[position].data() ==
nullptr &&
830 "another kind of results already set");
831 assert(values[position].data() ==
nullptr &&
832 "another kind of results already set");
833 operations.replace(position, std::forward<Range>(ops));
840 void set(
OpResult value, std::initializer_list<Operation *> ops) {
855 template <
typename Range>
858 assert(position <
static_cast<int64_t
>(this->values.size()) &&
859 "setting values for a non-existent handle");
860 assert(this->values[position].data() ==
nullptr &&
"values already set");
861 assert(operations[position].data() ==
nullptr &&
862 "another kind of results already set");
863 assert(params[position].data() ==
nullptr &&
864 "another kind of results already set");
865 this->values.replace(position, std::forward<Range>(values));
912 bool isParam(
unsigned resultNumber)
const;
917 bool isValue(
unsigned resultNumber)
const;
921 bool isSet(
unsigned resultNumber)
const;
1049 void notifyOperationErased(
Operation *op)
override;
1052 using Listener::notifyOperationReplaced;
1055 TransformOpInterface transformOp;
1099 int64_t errorCounter = 0;
1102 std::optional<Diagnostic> matchFailure;
1154 template <
typename OpTy>
1181 assert(region.
getParentOp() == this->getOperation() &&
1182 "op comes from the wrong region");
1189 "must indicate the region to map if the operation has more than one");
1194 class ApplyToEachResultList;
1216 template <
typename OpTy>
1259 StringRef
getName()
override {
return "transform.mapping"; }
1270 StringRef
getName()
override {
return "transform.payload_ir"; }
1303 Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1307 template <
typename OpTy>
1324 <<
"FunctionalStyleTransformOpTrait should only be attached to ops "
1325 "that implement MemoryEffectOpInterface";
1334 template <
typename OpTy>
1344 return isa<TransformHandleTypeInterface,
1345 TransformValueHandleTypeInterface>(t);
1354 op->
emitError() <<
"NavigationTransformOpTrait should only be attached "
1355 "to ops that implement MemoryEffectOpInterface";
1375 template <
typename OpTy>
1399 template <
typename OpTy>
1418 template <
typename Range>
1424 results.reserve(llvm::size(range));
1425 for (
auto element : range) {
1426 if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1428 results.push_back(
static_cast<Operation *
>(element));
1429 }
else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1431 results.push_back(element.template get<Value>());
1433 results.push_back(
static_cast<Attribute>(element));
1450 auto begin() {
return results.begin(); }
1451 auto end() {
return results.end(); }
1452 auto begin()
const {
return results.begin(); }
1453 auto end()
const {
return results.end(); }
1456 size_t size()
const {
return results.size(); }
1461 return results[index];
1475 const ApplyToEachResultList &partialResult);
1482 TransformResults &transformResults,
1496 template <
typename TransformOpTy,
typename Range>
1500 using OpTy =
typename llvm::function_traits<
1501 decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1502 static_assert(std::is_convertible<OpTy, Operation *>::value,
1503 "expected transform function to take an operation");
1507 unsigned expectedNumResults = transformOp->getNumResults();
1509 auto specificOp = dyn_cast<OpTy>(target);
1512 diag <<
"transform applied to the wrong op kind";
1513 diag.attachNote(target->getLoc()) <<
"when applied to this op";
1514 silenceableStack.push_back(std::move(
diag));
1519 partialResults.
reserve(expectedNumResults);
1520 Location specificOpLoc = specificOp->getLoc();
1523 transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1536 results.push_back(std::move(partialResults));
1538 if (!silenceableStack.empty()) {
1540 std::move(silenceableStack));
1555 template <
typename OpTy>
1560 Value handle = this->getOperation()->getOperand(0);
1561 auto targets = state.getPayloadOps(handle);
1565 if (state.getOptions().getExpensiveChecksEnabled() &&
1567 this->getOperation())) &&
1569 llvm::to_vector(targets)))) {
1570 return DiagnosedSilenceableFailure::definiteFailure();
1577 if (std::empty(targets)) {
1580 for (
OpResult r : this->getOperation()->getResults()) {
1581 if (isa<TransformParamTypeInterface>(r.getType()))
1582 transformResults.
setParams(r, emptyParams);
1583 else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1586 transformResults.
set(r, emptyPayload);
1588 return DiagnosedSilenceableFailure::success();
1595 cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1611 template <
typename OpTy>
1614 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1615 "expected single-operand op");
1617 return op->
emitError() <<
"TransformEachOpTrait should only be attached to "
1618 "ops that implement TransformOpInterface";
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given diagnostic.
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Helper class for implementing traits.
Operation * getOperation()
Return the ultimate Operation being worked on.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operation, null otherwise.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
A 2D array where each row may have different length.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This base class is used for derived effects that are non-parametric.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Region * getParentRegion()
Return the Region in which this Value is defined.
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...