9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
20 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
25 class TransformOpInterface;
26 class TransformResults;
27 class TransformRewriter;
78 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
95 expensiveChecksEnabled = enable;
101 enforceSingleToplevelTransformOp = enable;
110 return enforceSingleToplevelTransformOp;
114 bool expensiveChecksEnabled =
true;
115 bool enforceSingleToplevelTransformOp =
true;
127 const TransformOptions &
options = TransformOptions(),
128 bool enforceToplevelTransformOp =
true);
187 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
202 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
203 TransformIRTimestampMapping timestamps;
/// Increments the counter stored for `value` in `timestamps`. This sits
/// directly under an `#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS` guard, and the
/// payload-range accessors below compare a captured timestamp against the
/// current one to assert that an iterator was not invalidated mid-iteration.
204 void incrementTimestamp(
Value value) { ++timestamps[value]; }
227 return topLevelMappedValues[position];
237 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
242 int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
247 return llvm::make_filter_range(view, [=](
Operation *op) {
248 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
249 [[maybe_unused]]
bool sameTimestamp =
250 currentTimestamp == this->getMapping(value).timestamps.lookup(value);
251 assert(sameTimestamp &&
"iterator was invalidated during iteration");
253 return op !=
nullptr;
266 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
271 int64_t currentTimestamp =
272 getMapping(handleValue).timestamps.lookup(handleValue);
273 return llvm::make_filter_range(view, [=](
Value v) {
274 [[maybe_unused]]
bool sameTimestamp =
276 this->getMapping(handleValue).timestamps.lookup(handleValue);
277 assert(sameTimestamp &&
"iterator was invalidated during iteration");
281 return llvm::make_range(view.begin(), view.end());
291 bool includeOutOfScope =
false)
const;
299 bool includeOutOfScope =
false)
const;
315 "mapping block arguments from a region other than the active one");
316 return setPayloadOps(argument, operations);
347 : state(state), region(&region) {
348 auto res = state.mappings.insert(
349 std::make_pair(&region, std::make_unique<Mappings>()));
350 assert(res.second &&
"the region scope is already present");
352 state.regionStack.push_back(
this);
362 TransformOpInterface currentTransform;
435 template <
typename Ty,
typename... Args>
438 std::is_base_of<Extension, Ty>::value,
439 "only an class derived from TransformState::Extension is allowed here");
440 auto ptr = std::make_unique<Ty>(*
this, std::forward<Args>(args)...);
441 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
442 assert(result.second &&
"extension already added");
443 return *
static_cast<Ty *
>(result.first->second.get());
447 template <
typename Ty>
450 std::is_base_of<Extension, Ty>::value,
451 "only an class derived from TransformState::Extension is allowed here");
452 auto iter = extensions.find(TypeID::get<Ty>());
453 if (iter == extensions.end())
455 return static_cast<Ty *
>(iter->second.get());
459 template <
typename Ty>
462 std::is_base_of<Extension, Ty>::value,
463 "only an class derived from TransformState::Extension is allowed here");
464 extensions.erase(TypeID::get<Ty>());
469 static constexpr
Value kTopLevelValue =
Value();
477 const TransformOptions &
options = TransformOptions());
482 const Mappings &getMapping(
Value value,
bool allowOutOfScope =
false)
const {
483 return const_cast<TransformState *
>(
this)->getMapping(value,
486 Mappings &getMapping(Value value,
bool allowOutOfScope =
false) {
488 auto it = mappings.find(region);
489 assert(it != mappings.end() &&
490 "trying to find a mapping for a value from an unmapped region");
492 if (!allowOutOfScope) {
493 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
496 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
497 llvm_unreachable(
"trying to get mapping beyond region that is "
498 "isolated from above");
508 const Mappings &getMapping(Operation *operation,
509 bool allowOutOfScope =
false)
const {
510 return const_cast<TransformState *
>(
this)->getMapping(operation,
513 Mappings &getMapping(Operation *operation,
bool allowOutOfScope =
false) {
514 Region *region = operation->getParentRegion();
515 auto it = mappings.find(region);
516 assert(it != mappings.end() &&
517 "trying to find a mapping for an operation from an unmapped region");
519 if (!allowOutOfScope) {
520 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
523 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
524 llvm_unreachable(
"trying to get mapping beyond region that is "
525 "isolated from above");
534 LogicalResult updateStateFromResults(
const TransformResults &results,
535 ResultRange opResults);
540 ArrayRef<Operation *> getPayloadOpsView(Value value)
const;
544 ArrayRef<Value> getPayloadValuesView(Value handleValue)
const;
581 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
622 LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
628 LogicalResult setParams(Value value, ArrayRef<Param> params);
636 void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
637 bool allowOutOfScope =
false);
639 void forgetValueMapping(Value valueHandle,
640 ArrayRef<Operation *> payloadOperations);
648 LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
653 LogicalResult replacePayloadValue(Value value, Value replacement);
671 void recordOpHandleInvalidation(OpOperand &consumingHandle,
672 ArrayRef<Operation *> potentialAncestors,
674 InvalidatedHandleMap &newlyInvalidated)
const;
693 void recordOpHandleInvalidationOne(
694 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
695 Operation *payloadOp, Value otherHandle, Value throughValue,
696 InvalidatedHandleMap &newlyInvalidated)
const;
713 void recordValueHandleInvalidationByOpHandleOne(
714 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
715 Value payloadValue, Value valueHandle,
716 InvalidatedHandleMap &newlyInvalidated)
const;
732 recordValueHandleInvalidation(OpOperand &valueHandle,
733 InvalidatedHandleMap &newlyInvalidated)
const;
741 checkAndRecordHandleInvalidation(TransformOpInterface transform);
746 LogicalResult checkAndRecordHandleInvalidationImpl(
747 transform::TransformOpInterface transform,
748 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const;
751 void compactOpHandles();
758 llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
764 DenseSet<Value> opHandlesToCompact;
768 DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
776 RaggedArray<MappedValue> topLevelMappedValues;
785 InvalidatedHandleMap invalidatedHandles;
790 SmallVector<RegionScope *> regionStack;
794 std::unique_ptr<RegionScope> topLevelRegionScope;
807 template <
typename Range>
810 assert(position <
static_cast<int64_t
>(operations.size()) &&
811 "setting results for a non-existent handle");
812 assert(operations[position].data() ==
nullptr &&
"results already set");
813 assert(params[position].data() ==
nullptr &&
814 "another kind of results already set");
815 assert(values[position].data() ==
nullptr &&
816 "another kind of results already set");
817 operations.replace(position, std::forward<Range>(ops));
824 void set(
OpResult value, std::initializer_list<Operation *> ops) {
839 template <
typename Range>
842 assert(position <
static_cast<int64_t
>(this->values.size()) &&
843 "setting values for a non-existent handle");
844 assert(this->values[position].data() ==
nullptr &&
"values already set");
845 assert(operations[position].data() ==
nullptr &&
846 "another kind of results already set");
847 assert(params[position].data() ==
nullptr &&
848 "another kind of results already set");
849 this->values.replace(position, std::forward<Range>(values));
896 bool isParam(
unsigned resultNumber)
const;
901 bool isValue(
unsigned resultNumber)
const;
905 bool isSet(
unsigned resultNumber)
const;
1033 void notifyOperationErased(
Operation *op)
override;
1036 using Listener::notifyOperationReplaced;
1039 TransformOpInterface transformOp;
1075 int64_t errorCounter = 0;
1127 template <
typename OpTy>
1154 assert(region.
getParentOp() == this->getOperation() &&
1155 "op comes from the wrong region");
1162 "must indicate the region to map if the operation has more than one");
1167 class ApplyToEachResultList;
1189 template <
typename OpTy>
/// Overrides the base-class `getName()` and returns the fixed identifier
/// string "transform.mapping".
1232 StringRef
getName()
override {
return "transform.mapping"; }
/// Overrides the base-class `getName()` and returns the fixed identifier
/// string "transform.payload_ir".
1243 StringRef
getName()
override {
return "transform.payload_ir"; }
1274 Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1278 template <
typename OpTy>
1293 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1295 <<
"FunctionalStyleTransformOpTrait should only be attached to ops "
1296 "that implement MemoryEffectOpInterface";
1305 template <
typename OpTy>
1315 return isa<TransformHandleTypeInterface,
1316 TransformValueHandleTypeInterface>(t);
1324 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1325 op->emitError() <<
"NavigationTransformOpTrait should only be attached "
1326 "to ops that implement MemoryEffectOpInterface";
1346 template <
typename OpTy>
1370 template <
typename OpTy>
1389 template <
typename Range>
1395 results.reserve(llvm::size(range));
1396 for (
auto element : range) {
1397 if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1399 results.push_back(
static_cast<Operation *
>(element));
1400 }
else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1402 results.push_back(element.template get<Value>());
1404 results.push_back(
static_cast<Attribute>(element));
/// Non-const overload: forwards to `results.begin()`.
1421 auto begin() {
return results.begin(); }
/// Non-const overload: forwards to `results.end()`.
1422 auto end() {
return results.end(); }
/// Const overload: forwards to `results.begin()` on the const container.
1423 auto begin()
const {
return results.begin(); }
/// Const overload: forwards to `results.end()` on the const container.
1424 auto end()
const {
return results.end(); }
/// Returns the number of elements currently held in `results`.
1427 size_t size()
const {
return results.size(); }
1432 return results[index];
1446 const ApplyToEachResultList &partialResult);
1453 TransformResults &transformResults,
1467 template <
typename TransformOpTy,
typename Range>
1471 using OpTy =
typename llvm::function_traits<
1472 decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1473 static_assert(std::is_convertible<OpTy, Operation *>::value,
1474 "expected transform function to take an operation");
1478 unsigned expectedNumResults = transformOp->getNumResults();
1480 auto specificOp = dyn_cast<OpTy>(target);
1483 diag <<
"transform applied to the wrong op kind";
1484 diag.attachNote(target->getLoc()) <<
"when applied to this op";
1485 silenceableStack.push_back(std::move(
diag));
1490 partialResults.
reserve(expectedNumResults);
1491 Location specificOpLoc = specificOp->getLoc();
1494 transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1507 results.push_back(std::move(partialResults));
1509 if (!silenceableStack.empty()) {
1511 std::move(silenceableStack));
1526 template <
typename OpTy>
1531 Value handle = this->getOperation()->getOperand(0);
1532 auto targets = state.getPayloadOps(handle);
1536 if (state.getOptions().getExpensiveChecksEnabled() &&
1538 this->getOperation())) &&
1540 llvm::to_vector(targets)))) {
1541 return DiagnosedSilenceableFailure::definiteFailure();
1548 if (std::empty(targets)) {
1551 for (
OpResult r : this->getOperation()->getResults()) {
1552 if (isa<TransformParamTypeInterface>(r.getType()))
1553 transformResults.
setParams(r, emptyParams);
1554 else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1557 transformResults.
set(r, emptyPayload);
1566 cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1582 template <
typename OpTy>
1585 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1586 "expected single-operand op");
1587 if (!op->getName().getInterface<TransformOpInterface>()) {
1588 return op->emitError() <<
"TransformEachOpTrait should only be attached to "
1589 "ops that implement TransformOpInterface";
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given diagnostic.
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Helper class for implementing traits.
Operation * getOperation()
Return the ultimate Operation being worked on.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A 2D array where each row may have different length.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This base class is used for derived effects that are non-parametric.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Region * getParentRegion()
Return the Region in which this Value is defined.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.