9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
19 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
24 class TransformOpInterface;
25 class TransformResults;
26 class TransformRewriter;
88 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
105 expensiveChecksEnabled = enable;
111 enforceSingleToplevelTransformOp = enable;
120 return enforceSingleToplevelTransformOp;
124 bool expensiveChecksEnabled =
true;
125 bool enforceSingleToplevelTransformOp =
true;
135 Operation *payloadRoot, TransformOpInterface transform,
137 const TransformOptions &
options = TransformOptions(),
138 bool enforceToplevelTransformOp =
true,
139 function_ref<void(TransformState &)> stateInitializer =
nullptr,
140 function_ref<LogicalResult(TransformState &)> stateExporter =
nullptr);
199 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
214 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
215 TransformIRTimestampMapping timestamps;
/// Increments by one the entry of `timestamps` keyed by `value`.
/// NOTE(review): the listing shows an `#if LLVM_ENABLE_ABI_BREAKING_CHECKS`
/// line just above, so this member is presumably compiled only when those
/// checks are enabled - confirm against the full header.
216 void incrementTimestamp(
Value value) { ++timestamps[value]; }
241 return topLevelMappedValues[position];
251 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
256 int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
261 return llvm::make_filter_range(view, [=](
Operation *op) {
262 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
263 [[maybe_unused]]
bool sameTimestamp =
264 currentTimestamp == this->getMapping(value).timestamps.lookup(value);
265 assert(sameTimestamp &&
"iterator was invalidated during iteration");
267 return op !=
nullptr;
280 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
285 int64_t currentTimestamp =
286 getMapping(handleValue).timestamps.lookup(handleValue);
287 return llvm::make_filter_range(view, [=](
Value v) {
288 [[maybe_unused]]
bool sameTimestamp =
290 this->getMapping(handleValue).timestamps.lookup(handleValue);
291 assert(sameTimestamp &&
"iterator was invalidated during iteration");
295 return llvm::make_range(view.begin(), view.end());
305 bool includeOutOfScope =
false)
const;
313 bool includeOutOfScope =
false)
const;
329 "mapping block arguments from a region other than the active one");
330 return setPayloadOps(argument, operations);
363 : state(state), region(&region) {
364 auto res = state.mappings.insert(
365 std::make_pair(&region, std::make_unique<Mappings>()));
366 assert(res.second &&
"the region scope is already present");
368 state.regionStack.push_back(
this);
378 TransformOpInterface currentTransform;
451 template <
typename Ty,
typename... Args>
454 std::is_base_of<Extension, Ty>::value,
455 "only an class derived from TransformState::Extension is allowed here");
456 auto ptr = std::make_unique<Ty>(*
this, std::forward<Args>(args)...);
457 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
458 assert(result.second &&
"extension already added");
459 return *
static_cast<Ty *
>(result.first->second.get());
463 template <
typename Ty>
466 std::is_base_of<Extension, Ty>::value,
467 "only an class derived from TransformState::Extension is allowed here");
468 auto iter = extensions.find(TypeID::get<Ty>());
469 if (iter == extensions.end())
471 return static_cast<Ty *
>(iter->second.get());
475 template <
typename Ty>
478 std::is_base_of<Extension, Ty>::value,
479 "only an class derived from TransformState::Extension is allowed here");
480 extensions.erase(TypeID::get<Ty>());
/// Compile-time constant holding a default-constructed `Value`.
/// NOTE(review): judging by the name, this presumably acts as a sentinel
/// standing for the top level rather than a real handle - confirm at its uses.
485 static constexpr
Value kTopLevelValue =
Value();
493 const TransformOptions &
options = TransformOptions());
498 const Mappings &getMapping(
Value value,
bool allowOutOfScope =
false)
const {
499 return const_cast<TransformState *
>(
this)->getMapping(value,
502 Mappings &getMapping(Value value,
bool allowOutOfScope =
false) {
504 auto it = mappings.find(region);
505 assert(it != mappings.end() &&
506 "trying to find a mapping for a value from an unmapped region");
508 if (!allowOutOfScope) {
509 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
512 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
513 llvm_unreachable(
"trying to get mapping beyond region that is "
514 "isolated from above");
524 const Mappings &getMapping(Operation *operation,
525 bool allowOutOfScope =
false)
const {
526 return const_cast<TransformState *
>(
this)->getMapping(operation,
529 Mappings &getMapping(Operation *operation,
bool allowOutOfScope =
false) {
530 Region *region = operation->getParentRegion();
531 auto it = mappings.find(region);
532 assert(it != mappings.end() &&
533 "trying to find a mapping for an operation from an unmapped region");
535 if (!allowOutOfScope) {
536 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
539 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
540 llvm_unreachable(
"trying to get mapping beyond region that is "
541 "isolated from above");
550 LogicalResult updateStateFromResults(
const TransformResults &results,
551 ResultRange opResults);
556 ArrayRef<Operation *> getPayloadOpsView(Value value)
const;
560 ArrayRef<Value> getPayloadValuesView(Value handleValue)
const;
597 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
638 LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
644 LogicalResult setParams(Value value, ArrayRef<Param> params);
652 void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
653 bool allowOutOfScope =
false);
655 void forgetValueMapping(Value valueHandle,
656 ArrayRef<Operation *> payloadOperations);
664 LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
669 LogicalResult replacePayloadValue(Value value, Value replacement);
687 void recordOpHandleInvalidation(OpOperand &consumingHandle,
688 ArrayRef<Operation *> potentialAncestors,
690 InvalidatedHandleMap &newlyInvalidated)
const;
709 void recordOpHandleInvalidationOne(
710 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
711 Operation *payloadOp, Value otherHandle, Value throughValue,
712 InvalidatedHandleMap &newlyInvalidated)
const;
729 void recordValueHandleInvalidationByOpHandleOne(
730 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
731 Value payloadValue, Value valueHandle,
732 InvalidatedHandleMap &newlyInvalidated)
const;
748 recordValueHandleInvalidation(OpOperand &valueHandle,
749 InvalidatedHandleMap &newlyInvalidated)
const;
757 checkAndRecordHandleInvalidation(TransformOpInterface transform);
762 LogicalResult checkAndRecordHandleInvalidationImpl(
763 transform::TransformOpInterface transform,
764 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const;
767 void compactOpHandles();
774 llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
780 DenseSet<Value> opHandlesToCompact;
784 DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
792 RaggedArray<MappedValue> topLevelMappedValues;
801 InvalidatedHandleMap invalidatedHandles;
806 SmallVector<RegionScope *> regionStack;
810 std::unique_ptr<RegionScope> topLevelRegionScope;
823 template <
typename Range>
826 assert(position <
static_cast<int64_t
>(operations.size()) &&
827 "setting results for a non-existent handle");
828 assert(operations[position].data() ==
nullptr &&
"results already set");
829 assert(params[position].data() ==
nullptr &&
830 "another kind of results already set");
831 assert(values[position].data() ==
nullptr &&
832 "another kind of results already set");
833 operations.replace(position, std::forward<Range>(ops));
840 void set(
OpResult value, std::initializer_list<Operation *> ops) {
855 template <
typename Range>
858 assert(position <
static_cast<int64_t
>(this->values.size()) &&
859 "setting values for a non-existent handle");
860 assert(this->values[position].data() ==
nullptr &&
"values already set");
861 assert(operations[position].data() ==
nullptr &&
862 "another kind of results already set");
863 assert(params[position].data() ==
nullptr &&
864 "another kind of results already set");
865 this->values.replace(position, std::forward<Range>(values));
912 bool isParam(
unsigned resultNumber)
const;
917 bool isValue(
unsigned resultNumber)
const;
921 bool isSet(
unsigned resultNumber)
const;
1049 void notifyOperationErased(
Operation *op)
override;
1052 using Listener::notifyOperationReplaced;
1055 TransformOpInterface transformOp;
1091 int64_t errorCounter = 0;
1143 template <
typename OpTy>
1170 assert(region.
getParentOp() == this->getOperation() &&
1171 "op comes from the wrong region");
1178 "must indicate the region to map if the operation has more than one");
1183 class ApplyToEachResultList;
1205 template <
typename OpTy>
/// Returns the fixed name string "transform.mapping". Marked `override`, so it
/// implements a virtual `getName` declared in a base class that is not visible
/// in this listing.
1248 StringRef
getName()
override {
return "transform.mapping"; }
/// Returns the fixed name string "transform.payload_ir". Marked `override`, so
/// it implements a virtual `getName` declared in a base class that is not
/// visible in this listing.
1259 StringRef
getName()
override {
return "transform.payload_ir"; }
1292 Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1296 template <
typename OpTy>
1313 <<
"FunctionalStyleTransformOpTrait should only be attached to ops "
1314 "that implement MemoryEffectOpInterface";
1323 template <
typename OpTy>
1333 return isa<TransformHandleTypeInterface,
1334 TransformValueHandleTypeInterface>(t);
1343 op->
emitError() <<
"NavigationTransformOpTrait should only be attached "
1344 "to ops that implement MemoryEffectOpInterface";
1364 template <
typename OpTy>
1388 template <
typename OpTy>
1407 template <
typename Range>
1413 results.reserve(llvm::size(range));
1414 for (
auto element : range) {
1415 if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1417 results.push_back(
static_cast<Operation *
>(element));
1418 }
else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1420 results.push_back(element.template get<Value>());
1422 results.push_back(
static_cast<Attribute>(element));
/// Mutable iteration: forwards to `results.begin()` / `results.end()`, which
/// lets the enclosing object be used directly in a range-based `for`.
1439 auto begin() {
return results.begin(); }
1440 auto end() {
return results.end(); }
/// Const overloads of the same pair; they forward to the const `begin()` /
/// `end()` of `results`.
1441 auto begin()
const {
return results.begin(); }
1442 auto end()
const {
return results.end(); }
/// Returns the number of elements currently stored in `results`.
1445 size_t size()
const {
return results.size(); }
1450 return results[index];
1464 const ApplyToEachResultList &partialResult);
1471 TransformResults &transformResults,
1485 template <
typename TransformOpTy,
typename Range>
1489 using OpTy =
typename llvm::function_traits<
1490 decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1491 static_assert(std::is_convertible<OpTy, Operation *>::value,
1492 "expected transform function to take an operation");
1496 unsigned expectedNumResults = transformOp->getNumResults();
1498 auto specificOp = dyn_cast<OpTy>(target);
1501 diag <<
"transform applied to the wrong op kind";
1502 diag.attachNote(target->getLoc()) <<
"when applied to this op";
1503 silenceableStack.push_back(std::move(
diag));
1508 partialResults.
reserve(expectedNumResults);
1509 Location specificOpLoc = specificOp->getLoc();
1512 transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1525 results.push_back(std::move(partialResults));
1527 if (!silenceableStack.empty()) {
1529 std::move(silenceableStack));
1544 template <
typename OpTy>
1549 Value handle = this->getOperation()->getOperand(0);
1550 auto targets = state.getPayloadOps(handle);
1554 if (state.getOptions().getExpensiveChecksEnabled() &&
1556 this->getOperation())) &&
1558 llvm::to_vector(targets)))) {
1559 return DiagnosedSilenceableFailure::definiteFailure();
1566 if (std::empty(targets)) {
1569 for (
OpResult r : this->getOperation()->getResults()) {
1570 if (isa<TransformParamTypeInterface>(r.getType()))
1571 transformResults.
setParams(r, emptyParams);
1572 else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1575 transformResults.
set(r, emptyPayload);
1577 return DiagnosedSilenceableFailure::success();
1584 cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1600 template <
typename OpTy>
1603 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1604 "expected single-operand op");
1606 return op->
emitError() <<
"TransformEachOpTrait should only be attached to "
1607 "ops that implement TransformOpInterface";
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given diagnostic.
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Helper class for implementing traits.
Operation * getOperation()
Return the ultimate Operation being worked on.
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operation, null otherwise.
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
OperationName getName()
The name of an operation is the key identifier for it.
A 2D array where each row may have different length.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This base class is used for derived effects that are non-parametric.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Region * getParentRegion()
Return the Region in which this Value is defined.
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
const FrozenRewritePatternSet GreedyRewriteConfig config
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.