9 #ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10 #define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
19 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
24 class TransformOpInterface;
25 class TransformResults;
26 class TransformRewriter;
88 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
105 expensiveChecksEnabled = enable;
111 enforceSingleToplevelTransformOp = enable;
120 return enforceSingleToplevelTransformOp;
/// Backing flag for the expensive-checks option; enabled unless turned off.
bool expensiveChecksEnabled{true};
/// Backing flag for the single-top-level-transform-op enforcement option;
/// enabled unless turned off.
bool enforceSingleToplevelTransformOp{true};
137 const TransformOptions &
options = TransformOptions(),
138 bool enforceToplevelTransformOp =
true);
197 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
212 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
213 TransformIRTimestampMapping timestamps;
214 void incrementTimestamp(
Value value) { ++timestamps[value]; }
237 return topLevelMappedValues[position];
247 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
252 int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
257 return llvm::make_filter_range(view, [=](
Operation *op) {
258 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
259 [[maybe_unused]]
bool sameTimestamp =
260 currentTimestamp == this->getMapping(value).timestamps.lookup(value);
261 assert(sameTimestamp &&
"iterator was invalidated during iteration");
263 return op !=
nullptr;
276 #ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
281 int64_t currentTimestamp =
282 getMapping(handleValue).timestamps.lookup(handleValue);
283 return llvm::make_filter_range(view, [=](
Value v) {
284 [[maybe_unused]]
bool sameTimestamp =
286 this->getMapping(handleValue).timestamps.lookup(handleValue);
287 assert(sameTimestamp &&
"iterator was invalidated during iteration");
291 return llvm::make_range(view.begin(), view.end());
301 bool includeOutOfScope =
false)
const;
309 bool includeOutOfScope =
false)
const;
325 "mapping block arguments from a region other than the active one");
326 return setPayloadOps(argument, operations);
359 : state(state), region(®ion) {
360 auto res = state.mappings.insert(
361 std::make_pair(®ion, std::make_unique<Mappings>()));
362 assert(res.second &&
"the region scope is already present");
364 state.regionStack.push_back(
this);
/// Member holding a transform op.
/// NOTE(review): presumably the transform currently being applied; confirm
/// against where it is assigned.
374 TransformOpInterface currentTransform;
447 template <
typename Ty,
typename... Args>
450 std::is_base_of<Extension, Ty>::value,
451 "only an class derived from TransformState::Extension is allowed here");
452 auto ptr = std::make_unique<Ty>(*
this, std::forward<Args>(args)...);
453 auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
454 assert(result.second &&
"extension already added");
455 return *
static_cast<Ty *
>(result.first->second.get());
459 template <
typename Ty>
462 std::is_base_of<Extension, Ty>::value,
463 "only an class derived from TransformState::Extension is allowed here");
464 auto iter = extensions.find(TypeID::get<Ty>());
465 if (iter == extensions.end())
467 return static_cast<Ty *
>(iter->second.get());
471 template <
typename Ty>
474 std::is_base_of<Extension, Ty>::value,
475 "only an class derived from TransformState::Extension is allowed here");
476 extensions.erase(TypeID::get<Ty>());
481 static constexpr
Value kTopLevelValue =
Value();
489 const TransformOptions &
options = TransformOptions());
494 const Mappings &getMapping(
Value value,
bool allowOutOfScope =
false)
const {
495 return const_cast<TransformState *
>(
this)->getMapping(value,
498 Mappings &getMapping(Value value,
bool allowOutOfScope =
false) {
500 auto it = mappings.find(region);
501 assert(it != mappings.end() &&
502 "trying to find a mapping for a value from an unmapped region");
504 if (!allowOutOfScope) {
505 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
508 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
509 llvm_unreachable(
"trying to get mapping beyond region that is "
510 "isolated from above");
520 const Mappings &getMapping(Operation *operation,
521 bool allowOutOfScope =
false)
const {
522 return const_cast<TransformState *
>(
this)->getMapping(operation,
525 Mappings &getMapping(Operation *operation,
bool allowOutOfScope =
false) {
526 Region *region = operation->getParentRegion();
527 auto it = mappings.find(region);
528 assert(it != mappings.end() &&
529 "trying to find a mapping for an operation from an unmapped region");
531 if (!allowOutOfScope) {
532 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
535 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
536 llvm_unreachable(
"trying to get mapping beyond region that is "
537 "isolated from above");
/// Updates the state from `results`, which correspond to `opResults`.
/// NOTE(review): description inferred from the signature; confirm against the
/// out-of-line definition.
546 LogicalResult updateStateFromResults(
const TransformResults &results,
547 ResultRange opResults);
/// Returns the operations associated with the handle `value` as a raw view.
/// NOTE(review): the filtering accessors in this class skip null entries,
/// which suggests this raw view may contain them; confirm.
552 ArrayRef<Operation *> getPayloadOpsView(Value value)
const;
/// Returns the values associated with the value handle `handleValue` as a
/// raw view.
556 ArrayRef<Value> getPayloadValuesView(Value handleValue)
const;
/// Associates the payload operations `targets` with the handle `value`.
/// Also used to map block arguments of the active region to operations.
593 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
/// Associates the payload values `payloadValues` with the value handle
/// `handle`.
634 LogicalResult setPayloadValues(Value handle, ValueRange payloadValues);
/// Associates the parameters `params` with the parameter handle `value`.
640 LogicalResult setParams(Value value, ArrayRef<Param> params);
/// Drops the mapping associated with the operation handle `opHandle`.
/// NOTE(review): the roles of `origOpFlatResults` and `allowOutOfScope` are
/// not visible here; `allowOutOfScope` presumably mirrors the getMapping()
/// flag of the same name; confirm.
648 void forgetMapping(Value opHandle, ValueRange origOpFlatResults,
649 bool allowOutOfScope =
false);
/// Drops the mapping associated with the value handle `valueHandle`.
/// NOTE(review): how `payloadOperations` is used is not visible here.
651 void forgetValueMapping(Value valueHandle,
652 ArrayRef<Operation *> payloadOperations);
/// Replaces the payload operation `op` with `replacement` in the mappings.
/// NOTE(review): inferred from the name; failure conditions are defined out
/// of line.
660 LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
/// Replaces the payload value `value` with `replacement` in the mappings.
/// NOTE(review): inferred from the name; confirm against the definition.
665 LogicalResult replacePayloadValue(Value value, Value replacement);
683 void recordOpHandleInvalidation(OpOperand &consumingHandle,
684 ArrayRef<Operation *> potentialAncestors,
686 InvalidatedHandleMap &newlyInvalidated)
const;
/// Records handle-invalidation information involving the payload operation
/// `payloadOp` and the handle `otherHandle` into `newlyInvalidated`.
/// As a const member it reports through `newlyInvalidated` rather than by
/// modifying the (non-mutable) state.
/// NOTE(review): the exact criteria (use of `potentialAncestors`,
/// `throughValue`, `consumingHandle`) are defined out of line.
705 void recordOpHandleInvalidationOne(
706 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
707 Operation *payloadOp, Value otherHandle, Value throughValue,
708 InvalidatedHandleMap &newlyInvalidated)
const;
/// Records into `newlyInvalidated` invalidation of the value handle
/// `valueHandle` (mapped to `payloadValue`) caused through the operation
/// handle `opHandle`.
/// NOTE(review): semantics inferred from parameter names; confirm against the
/// definition.
725 void recordValueHandleInvalidationByOpHandleOne(
726 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
727 Value payloadValue, Value valueHandle,
728 InvalidatedHandleMap &newlyInvalidated)
const;
744 recordValueHandleInvalidation(OpOperand &valueHandle,
745 InvalidatedHandleMap &newlyInvalidated)
const;
753 checkAndRecordHandleInvalidation(TransformOpInterface transform);
/// Helper for handle-invalidation checking of `transform`; results are
/// reported through `newlyInvalidated`.
/// NOTE(review): presumably called by checkAndRecordHandleInvalidation();
/// confirm.
758 LogicalResult checkAndRecordHandleInvalidationImpl(
759 transform::TransformOpInterface transform,
760 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const;
/// NOTE(review): presumably compacts the handles listed in
/// `opHandlesToCompact`; confirm against the definition.
763 void compactOpHandles();
/// Mappings keyed by region. A RegionScope inserts its region's entry on
/// construction, and getMapping() looks regions up here. The insertion order
/// kept by MapVector lets getMapping() walk regions in reverse.
770 llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
/// Operation handles pending compaction (see compactOpHandles()).
776 DenseSet<Value> opHandlesToCompact;
/// Extensions keyed by their TypeID. At most one per type (addExtension
/// asserts on duplicates); looked up and erased by type.
780 DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
/// Values mapped at the top level, indexed by position; each row may have a
/// different length.
788 RaggedArray<MappedValue> topLevelMappedValues;
/// Handles recorded as invalidated.
/// NOTE(review): population logic is not visible here.
797 InvalidatedHandleMap invalidatedHandles;
/// Stack of active RegionScopes; each RegionScope pushes itself on
/// construction.
802 SmallVector<RegionScope *> regionStack;
/// Owning pointer to the top-level RegionScope.
806 std::unique_ptr<RegionScope> topLevelRegionScope;
819 template <
typename Range>
822 assert(position <
static_cast<int64_t
>(operations.size()) &&
823 "setting results for a non-existent handle");
824 assert(operations[position].data() ==
nullptr &&
"results already set");
825 assert(params[position].data() ==
nullptr &&
826 "another kind of results already set");
827 assert(values[position].data() ==
nullptr &&
828 "another kind of results already set");
829 operations.replace(position, std::forward<Range>(ops));
836 void set(
OpResult value, std::initializer_list<Operation *> ops) {
851 template <
typename Range>
854 assert(position <
static_cast<int64_t
>(this->values.size()) &&
855 "setting values for a non-existent handle");
856 assert(this->values[position].data() ==
nullptr &&
"values already set");
857 assert(operations[position].data() ==
nullptr &&
858 "another kind of results already set");
859 assert(params[position].data() ==
nullptr &&
860 "another kind of results already set");
861 this->values.replace(position, std::forward<Range>(values));
/// Returns whether the result `resultNumber` holds parameters.
/// NOTE(review): inferred from the name (cf. setParams); confirm.
908 bool isParam(
unsigned resultNumber)
const;
/// Returns whether the result `resultNumber` holds payload values.
/// NOTE(review): inferred from the name; confirm.
913 bool isValue(
unsigned resultNumber)
const;
/// Returns whether any kind of result has been set for `resultNumber`.
/// NOTE(review): inferred from the name and the "results already set"
/// assertions in the setters; confirm.
917 bool isSet(
unsigned resultNumber)
const;
/// Listener hook invoked when `op` is erased; overrides the base-class hook.
1045 void notifyOperationErased(
Operation *op)
override;
/// Re-exposes the base Listener's notifyOperationReplaced overloads so that
/// any overload declared in this class does not hide them.
1048 using Listener::notifyOperationReplaced;
/// Transform op associated with this listener.
/// NOTE(review): presumably the op whose application is being tracked.
1051 TransformOpInterface transformOp;
/// Error counter, starting at zero.
/// NOTE(review): presumably incremented as errors are recorded; confirm.
int64_t errorCounter{0};
1139 template <
typename OpTy>
1166 assert(region.
getParentOp() == this->getOperation() &&
1167 "op comes from the wrong region");
1174 "must indicate the region to map if the operation has more than one");
1179 class ApplyToEachResultList;
1201 template <
typename OpTy>
1244 StringRef
getName()
override {
return "transform.mapping"; }
1255 StringRef
getName()
override {
return "transform.payload_ir"; }
1288 Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1292 template <
typename OpTy>
1307 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1309 <<
"FunctionalStyleTransformOpTrait should only be attached to ops "
1310 "that implement MemoryEffectOpInterface";
1319 template <
typename OpTy>
1329 return isa<TransformHandleTypeInterface,
1330 TransformValueHandleTypeInterface>(t);
1338 if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1339 op->emitError() <<
"NavigationTransformOpTrait should only be attached "
1340 "to ops that implement MemoryEffectOpInterface";
1360 template <
typename OpTy>
1384 template <
typename OpTy>
1403 template <
typename Range>
1409 results.reserve(llvm::size(range));
1410 for (
auto element : range) {
1411 if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1413 results.push_back(
static_cast<Operation *
>(element));
1414 }
else if constexpr (std::is_convertible_v<decltype(*std::begin(range)),
1416 results.push_back(element.template get<Value>());
1418 results.push_back(
static_cast<Attribute>(element));
1435 auto begin() {
return results.begin(); }
1436 auto end() {
return results.end(); }
1437 auto begin()
const {
return results.begin(); }
1438 auto end()
const {
return results.end(); }
1441 size_t size()
const {
return results.size(); }
1446 return results[index];
1460 const ApplyToEachResultList &partialResult);
1467 TransformResults &transformResults,
1481 template <
typename TransformOpTy,
typename Range>
1485 using OpTy =
typename llvm::function_traits<
1486 decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1487 static_assert(std::is_convertible<OpTy, Operation *>::value,
1488 "expected transform function to take an operation");
1492 unsigned expectedNumResults = transformOp->getNumResults();
1494 auto specificOp = dyn_cast<OpTy>(target);
1497 diag <<
"transform applied to the wrong op kind";
1498 diag.attachNote(target->getLoc()) <<
"when applied to this op";
1499 silenceableStack.push_back(std::move(
diag));
1504 partialResults.
reserve(expectedNumResults);
1505 Location specificOpLoc = specificOp->getLoc();
1508 transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1521 results.push_back(std::move(partialResults));
1523 if (!silenceableStack.empty()) {
1525 std::move(silenceableStack));
1540 template <
typename OpTy>
1545 Value handle = this->getOperation()->getOperand(0);
1546 auto targets = state.getPayloadOps(handle);
1550 if (state.getOptions().getExpensiveChecksEnabled() &&
1552 this->getOperation())) &&
1554 llvm::to_vector(targets)))) {
1555 return DiagnosedSilenceableFailure::definiteFailure();
1562 if (std::empty(targets)) {
1565 for (
OpResult r : this->getOperation()->getResults()) {
1566 if (isa<TransformParamTypeInterface>(r.getType()))
1567 transformResults.
setParams(r, emptyParams);
1568 else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1571 transformResults.
set(r, emptyPayload);
1573 return DiagnosedSilenceableFailure::success();
1580 cast<OpTy>(this->getOperation()), rewriter, targets, results, state);
1596 template <
typename OpTy>
1599 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1600 "expected single-operand op");
1601 if (!op->getName().getInterface<TransformOpInterface>()) {
1602 return op->emitError() <<
"TransformEachOpTrait should only be attached to "
1603 "ops that implement TransformOpInterface";
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given diagnostic.
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Helper class for implementing traits.
Operation * getOperation()
Return the ultimate Operation being worked on.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
A 2D array where each row may have different length.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This base class is used for derived effects that are non-parametric.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Region * getParentRegion()
Return the Region in which this Value is defined.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.