9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMINTERFACES_H
24 class TransformOpInterface;
25 class TransformResults;
26 class TransformRewriter;
77 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h.inc"
94 expensiveChecksEnabled = enable;
100 enforceSingleToplevelTransformOp = enable;
109 return enforceSingleToplevelTransformOp;
113 bool expensiveChecksEnabled =
true;
114 bool enforceSingleToplevelTransformOp =
true;
126 const TransformOptions &
options = TransformOptions(),
127 bool enforceToplevelTransformOp =
true);
186 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
201 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
202 TransformIRTimestampMapping timestamps;
203 void incrementTimestamp(
Value value) { ++timestamps[value]; }
226 return topLevelMappedValues[position];
236 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
241 int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
246 return llvm::make_filter_range(view, [=](
Operation *op) {
247 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
248 [[maybe_unused]]
bool sameTimestamp =
249 currentTimestamp == this->getMapping(value).timestamps.lookup(value);
250 assert(sameTimestamp &&
"iterator was invalidated during iteration");
252 return op !=
nullptr;
265 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
270 int64_t currentTimestamp =
271 getMapping(handleValue).timestamps.lookup(handleValue);
272 return llvm::make_filter_range(view, [=](
Value v) {
273 [[maybe_unused]]
bool sameTimestamp =
275 this->getMapping(handleValue).timestamps.lookup(handleValue);
276 assert(sameTimestamp &&
"iterator was invalidated during iteration");
280 return llvm::make_range(view.begin(), view.end());
290 bool includeOutOfScope =
false)
const;
298 bool includeOutOfScope =
false)
const;
313 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
315 "mapping block arguments from a region other than the active one");
317 return setPayloadOps(argument, operations);
348 : state(state), region(®ion) {
349 auto res = state.mappings.insert(
350 std::make_pair(®ion, std::make_unique<Mappings>()));
351 assert(res.second &&
"the region scope is already present");
353 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
354 state.regionStack.push_back(®ion);
435 template <
typename Ty,
typename... Args>
438 std::is_base_of<Extension, Ty>::value,
439 "only an class derived from TransformState::Extension is allowed here");
440 auto ptr = std::make_unique<Ty>(*
this, std::forward<Args>(args)...);
441 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
442 assert(result.second &&
"extension already added");
443 return *
static_cast<Ty *
>(result.first->second.get());
447 template <
typename Ty>
450 std::is_base_of<Extension, Ty>::value,
451 "only an class derived from TransformState::Extension is allowed here");
452 auto iter = extensions.find(TypeID::get<Ty>());
453 if (iter == extensions.end())
455 return static_cast<Ty *
>(iter->second.get());
459 template <
typename Ty>
462 std::is_base_of<Extension, Ty>::value,
463 "only an class derived from TransformState::Extension is allowed here");
464 extensions.erase(TypeID::get<Ty>());
469 static constexpr
Value kTopLevelValue =
Value();
477 const TransformOptions &
options = TransformOptions());
482 const Mappings &getMapping(
Value value,
bool allowOutOfScope =
false)
const {
483 return const_cast<TransformState *
>(
this)->getMapping(value,
486 Mappings &getMapping(Value value,
bool allowOutOfScope =
false) {
488 auto it = mappings.find(region);
489 assert(it != mappings.end() &&
490 "trying to find a mapping for a value from an unmapped region");
492 if (!allowOutOfScope) {
493 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
496 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
497 llvm_unreachable(
"trying to get mapping beyond region that is "
498 "isolated from above");
508 const Mappings &getMapping(Operation *operation,
509 bool allowOutOfScope =
false)
const {
510 return const_cast<TransformState *
>(
this)->getMapping(operation,
513 Mappings &getMapping(Operation *operation,
bool allowOutOfScope =
false) {
514 Region *region = operation->getParentRegion();
515 auto it = mappings.find(region);
516 assert(it != mappings.end() &&
517 "trying to find a mapping for an operation from an unmapped region");
519 if (!allowOutOfScope) {
520 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
523 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
524 llvm_unreachable(
"trying to get mapping beyond region that is "
525 "isolated from above");
534 LogicalResult updateStateFromResults(
const TransformResults &results,
535 ResultRange opResults);
540 ArrayRef<Operation *> getPayloadOpsView(Value value)
const;
544 ArrayRef<Value> getPayloadValuesView(Value handleValue)
const;
581 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
622 LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
628 LogicalResult setParams(Value value, ArrayRef<Param> params);
636 void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
637 bool allowOutOfScope =
false);
639 void forgetValueMapping(Value valueHandle,
640 ArrayRef<Operation *> payloadOperations);
648 LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
653 LogicalResult replacePayloadValue(Value value, Value replacement);
671 void recordOpHandleInvalidation(OpOperand &consumingHandle,
672 ArrayRef<Operation *> potentialAncestors,
674 InvalidatedHandleMap &newlyInvalidated)
const;
693 void recordOpHandleInvalidationOne(
694 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
695 Operation *payloadOp, Value otherHandle, Value throughValue,
696 InvalidatedHandleMap &newlyInvalidated)
const;
713 void recordValueHandleInvalidationByOpHandleOne(
714 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
715 Value payloadValue, Value valueHandle,
716 InvalidatedHandleMap &newlyInvalidated)
const;
732 recordValueHandleInvalidation(OpOperand &valueHandle,
733 InvalidatedHandleMap &newlyInvalidated)
const;
741 checkAndRecordHandleInvalidation(TransformOpInterface transform);
746 LogicalResult checkAndRecordHandleInvalidationImpl(
747 transform::TransformOpInterface transform,
748 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const;
751 void compactOpHandles();
758 llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
764 DenseSet<Value> opHandlesToCompact;
768 DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
776 RaggedArray<MappedValue> topLevelMappedValues;
785 InvalidatedHandleMap invalidatedHandles;
787 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
791 SmallVector<Region *> regionStack;
805 template <
typename Range>
808 assert(position <
static_cast<int64_t
>(operations.size()) &&
809 "setting results for a non-existent handle");
810 assert(operations[position].data() ==
nullptr &&
"results already set");
811 assert(params[position].data() ==
nullptr &&
812 "another kind of results already set");
813 assert(values[position].data() ==
nullptr &&
814 "another kind of results already set");
815 operations.replace(position, std::forward<Range>(ops));
822 void set(
OpResult value, std::initializer_list<Operation *> ops) {
837 template <
typename Range>
840 assert(position <
static_cast<int64_t
>(this->values.size()) &&
841 "setting values for a non-existent handle");
842 assert(this->values[position].data() ==
nullptr &&
"values already set");
843 assert(operations[position].data() ==
nullptr &&
844 "another kind of results already set");
845 assert(params[position].data() ==
nullptr &&
846 "another kind of results already set");
847 this->values.replace(position, std::forward<Range>(values));
894 bool isParam(
unsigned resultNumber)
const;
899 bool isValue(
unsigned resultNumber)
const;
903 bool isSet(
unsigned resultNumber)
const;
1008 void notifyOperationRemoved(
Operation *op)
override;
1011 using Listener::notifyOperationReplaced;
1014 TransformOpInterface transformOp;
1047 int64_t errorCounter = 0;
1099 template <
typename OpTy>
1126 assert(region.
getParentOp() == this->getOperation() &&
1127 "op comes from the wrong region");
1134 "must indicate the region to map if the operation has more than one");
1139 class ApplyToEachResultList;
1161 template <
typename OpTy>
1204 StringRef
getName()
override {
return "transform.mapping"; }
1215 StringRef
getName()
override {
return "transform.payload_ir"; }
1246 Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1250 template <
typename OpTy>
1265 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1267 <<
"FunctionalStyleTransformOpTrait should only be attached to ops "
1268 "that implement MemoryEffectOpInterface";
1277 template <
typename OpTy>
1291 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1292 "expected single-operand op");
1293 static_assert(OpTy::template hasTrait<OpTrait::OneResult>(),
1294 "expected single-result op");
1295 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1296 op->emitError() <<
"NavigationTransformOpTrait should only be attached "
1297 "to ops that implement MemoryEffectOpInterface";
1317 template <
typename OpTy>
1341 template <
typename OpTy>
1360 template <
typename Range>
1366 results.reserve(llvm::size(range));
1367 for (
auto element : range) {
1368 if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1370 results.push_back(
static_cast<Operation *
>(element));
1371 }
else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1373 results.push_back(element.template get<Value>());
1375 results.push_back(
static_cast<Attribute>(element));
1392 auto begin() {
return results.begin(); }
1393 auto end() {
return results.end(); }
1394 auto begin()
const {
return results.begin(); }
1395 auto end()
const {
return results.end(); }
1398 size_t size()
const {
return results.size(); }
1403 return results[index];
1417 const ApplyToEachResultList &partialResult);
1424 TransformResults &transformResults,
1438 template <
typename TransformOpTy,
typename Range>
1442 using OpTy =
typename llvm::function_traits<
1443 decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1444 static_assert(std::is_convertible<OpTy, Operation *>::value,
1445 "expected transform function to take an operation");
1449 unsigned expectedNumResults = transformOp->getNumResults();
1451 auto specificOp = dyn_cast<OpTy>(target);
1454 diag <<
"transform applied to the wrong op kind";
1455 diag.attachNote(target->getLoc()) <<
"when applied to this op";
1456 silenceableStack.push_back(std::move(
diag));
1461 partialResults.
reserve(expectedNumResults);
1462 Location specificOpLoc = specificOp->getLoc();
1465 transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1478 results.push_back(std::move(partialResults));
1480 if (!silenceableStack.empty()) {
1482 std::move(silenceableStack));
1497 template <
typename OpTy>
1502 Value handle = this->getOperation()->getOperand(0);
1503 auto targets = state.getPayloadOps(handle);
1507 if (state.getOptions().getExpensiveChecksEnabled() &&
1509 this->getOperation())) &&
1511 llvm::to_vector(targets)))) {
1512 return DiagnosedSilenceableFailure::definiteFailure();
1519 if (std::empty(targets)) {
1522 for (
OpResult r : this->getOperation()->getResults()) {
1523 if (isa<TransformParamTypeInterface>(r.getType()))
1524 transformResults.
setParams(r, emptyParams);
1525 else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1528 transformResults.
set(r, emptyPayload);
1537 cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1553 template <
typename OpTy>
1556 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1557 "expected single-operand op");
1558 if (!op->getName().getInterface<TransformOpInterface>()) {
1559 return op->emitError() <<
"TransformEachOpTrait should only be attached to "
1560 "ops that implement TransformOpInterface";
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Helper class for implementing traits.
Operation * getOperation()
Return the ultimate Operation being worked on.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A 2D array where each row may have different length.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * getParentOp()
Return the parent operation this region is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This base class is used for derived effects that are non-parametric.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Region * getParentRegion()
Return the Region in which this Value is defined.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...