9#ifndef MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
10#define MLIR_DIALECT_TRANSFORM_INTERFACES_TRANSFORMINTERFACES_H
19#include "mlir/Dialect/Transform/Interfaces/TransformAttrInterfaces.h.inc"
20#include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.h.inc"
25class TransformOpInterface;
104#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc"
121 expensiveChecksEnabled = enable;
127 enforceSingleToplevelTransformOp = enable;
136 return enforceSingleToplevelTransformOp;
140 bool expensiveChecksEnabled =
true;
141 bool enforceSingleToplevelTransformOp =
true;
153 const TransformOptions &
options = TransformOptions(),
154 bool enforceToplevelTransformOp =
true,
156 function_ref<LogicalResult(TransformState &)> stateExporter =
nullptr);
189class TransformState {
200 using TransformOpReverseMapping =
215#if LLVM_ENABLE_ABI_BREAKING_CHECKS
224 TransformOpMapping direct;
225 TransformOpReverseMapping reverse;
228 ValueMapping reverseValues;
230#if LLVM_ENABLE_ABI_BREAKING_CHECKS
231 TransformIRTimestampMapping timestamps;
232 void incrementTimestamp(
Value value) { ++timestamps[value]; }
242 friend TransformState
257 return topLevelMappedValues[position];
267#if LLVM_ENABLE_ABI_BREAKING_CHECKS
272 int64_t currentTimestamp = getMapping(value).timestamps.lookup(value);
277 return llvm::make_filter_range(view, [=](
Operation *op) {
278#if LLVM_ENABLE_ABI_BREAKING_CHECKS
279 [[maybe_unused]]
bool sameTimestamp =
280 currentTimestamp == this->getMapping(value).timestamps.lookup(value);
281 assert(sameTimestamp &&
"iterator was invalidated during iteration");
283 return op !=
nullptr;
296#if LLVM_ENABLE_ABI_BREAKING_CHECKS
302 getMapping(handleValue).timestamps.lookup(handleValue);
303 return llvm::make_filter_range(view, [=](
Value v) {
304 [[maybe_unused]]
bool sameTimestamp =
306 this->getMapping(handleValue).timestamps.lookup(handleValue);
307 assert(sameTimestamp &&
"iterator was invalidated during iteration");
311 return llvm::make_range(view.begin(), view.end());
321 bool includeOutOfScope =
false)
const;
329 bool includeOutOfScope =
false)
const;
345 "mapping block arguments from a region other than the active one");
346 return setPayloadOps(argument, operations);
379 : state(state), region(®ion) {
380 auto res = state.mappings.insert(
381 std::make_pair(®ion, std::make_unique<Mappings>()));
382 assert(res.second &&
"the region scope is already present");
384 state.regionStack.push_back(
this);
394 TransformOpInterface currentTransform;
458 TransformState &state;
467 template <
typename Ty,
typename... Args>
470 std::is_base_of<Extension, Ty>::value,
471 "only an class derived from TransformState::Extension is allowed here");
472 auto ptr = std::make_unique<Ty>(*
this, std::forward<Args>(args)...);
474 assert(
result.second &&
"extension already added");
475 return *
static_cast<Ty *
>(
result.first->second.get());
479 template <
typename Ty>
482 std::is_base_of<Extension, Ty>::value,
483 "only an class derived from TransformState::Extension is allowed here");
485 if (iter == extensions.end())
487 return static_cast<Ty *
>(iter->second.get());
491 template <
typename Ty>
494 std::is_base_of<Extension, Ty>::value,
495 "only an class derived from TransformState::Extension is allowed here");
501 static constexpr Value kTopLevelValue =
Value();
509 const TransformOptions &
options = TransformOptions());
514 const Mappings &getMapping(
Value value,
bool allowOutOfScope =
false)
const {
515 return const_cast<TransformState *
>(
this)->getMapping(value,
518 Mappings &getMapping(Value value,
bool allowOutOfScope =
false) {
520 auto it = mappings.find(region);
521 assert(it != mappings.end() &&
522 "trying to find a mapping for a value from an unmapped region");
524 if (!allowOutOfScope) {
525 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
528 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
529 llvm_unreachable(
"trying to get mapping beyond region that is "
530 "isolated from above");
540 const Mappings &getMapping(Operation *operation,
541 bool allowOutOfScope =
false)
const {
542 return const_cast<TransformState *
>(
this)->getMapping(operation,
545 Mappings &getMapping(Operation *operation,
bool allowOutOfScope =
false) {
546 Region *region = operation->getParentRegion();
547 auto it = mappings.find(region);
548 assert(it != mappings.end() &&
549 "trying to find a mapping for an operation from an unmapped region");
551 if (!allowOutOfScope) {
552 for (Region *r : llvm::reverse(llvm::make_first_range(mappings))) {
555 if (r->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
556 llvm_unreachable(
"trying to get mapping beyond region that is "
557 "isolated from above");
566 LogicalResult updateStateFromResults(
const TransformResults &results,
567 ResultRange opResults);
572 ArrayRef<Operation *> getPayloadOpsView(Value value)
const;
576 ArrayRef<Value> getPayloadValuesView(Value handleValue)
const;
613 LogicalResult setPayloadOps(Value value, ArrayRef<Operation *> targets);
654 LogicalResult setPayloadValues(Value handle,
ValueRange payloadValues);
660 LogicalResult setParams(Value value, ArrayRef<Param> params);
668 void forgetMapping(Value opHandle,
ValueRange origOpFlatResults,
669 bool allowOutOfScope =
false);
671 void forgetValueMapping(Value valueHandle,
672 ArrayRef<Operation *> payloadOperations);
680 LogicalResult replacePayloadOp(Operation *op, Operation *
replacement);
685 LogicalResult replacePayloadValue(Value value, Value
replacement);
703 void recordOpHandleInvalidation(OpOperand &consumingHandle,
704 ArrayRef<Operation *> potentialAncestors,
706 InvalidatedHandleMap &newlyInvalidated)
const;
725 void recordOpHandleInvalidationOne(
726 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
727 Operation *payloadOp, Value otherHandle, Value throughValue,
728 InvalidatedHandleMap &newlyInvalidated)
const;
745 void recordValueHandleInvalidationByOpHandleOne(
746 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
747 Value payloadValue, Value valueHandle,
748 InvalidatedHandleMap &newlyInvalidated)
const;
764 recordValueHandleInvalidation(OpOperand &valueHandle,
765 InvalidatedHandleMap &newlyInvalidated)
const;
773 checkAndRecordHandleInvalidation(TransformOpInterface transform);
778 LogicalResult checkAndRecordHandleInvalidationImpl(
779 transform::TransformOpInterface transform,
780 transform::TransformState::InvalidatedHandleMap &newlyInvalidated)
const;
783 void compactOpHandles();
790 llvm::MapVector<Region *, std::unique_ptr<Mappings>> mappings;
808 RaggedArray<MappedValue> topLevelMappedValues;
811 TransformOptions options;
817 InvalidatedHandleMap invalidatedHandles;
822 SmallVector<RegionScope *> regionStack;
826 std::unique_ptr<RegionScope> topLevelRegionScope;
831class TransformResults {
839 template <
typename Range>
842 assert(position <
static_cast<int64_t>(operations.size()) &&
843 "setting results for a non-existent handle");
844 assert(operations[position].data() ==
nullptr &&
"results already set");
845 assert(params[position].data() ==
nullptr &&
846 "another kind of results already set");
847 assert(values[position].data() ==
nullptr &&
848 "another kind of results already set");
849 operations.replace(position, std::forward<Range>(ops));
856 void set(
OpResult value, std::initializer_list<Operation *> ops) {
871 template <
typename Range>
873 int64_t position = handle.getResultNumber();
874 assert(position <
static_cast<int64_t>(this->values.size()) &&
875 "setting values for a non-existent handle");
876 assert(this->values[position].data() ==
nullptr &&
"values already set");
877 assert(operations[position].data() ==
nullptr &&
878 "another kind of results already set");
879 assert(params[position].data() ==
nullptr &&
880 "another kind of results already set");
881 this->values.replace(position, std::forward<Range>(values));
928 bool isParam(
unsigned resultNumber)
const;
933 bool isValue(
unsigned resultNumber)
const;
937 bool isSet(
unsigned resultNumber)
const;
1065 void notifyOperationErased(
Operation *op)
override;
1068 using Listener::notifyOperationReplaced;
1071 TransformOpInterface transformOp;
1118 std::optional<Diagnostic> matchFailure;
1170template <
typename OpTy>
1193 region.
front(), effects);
1200 assert(region.
getParentOp() == this->getOperation() &&
1201 "op comes from the wrong region");
1208 "must indicate the region to map if the operation has more than one");
1213class ApplyToEachResultList;
1235template <
typename OpTy>
1278 StringRef
getName()
const override {
return "transform.mapping"; }
1289 StringRef
getName()
const override {
return "transform.payload_ir"; }
1322 Block &block, llvm::SmallDenseSet<unsigned> &consumedArguments);
1326template <
typename OpTy>
1343 <<
"FunctionalStyleTransformOpTrait should only be attached to ops "
1344 "that implement MemoryEffectOpInterface";
1353template <
typename OpTy>
1363 return isa<TransformHandleTypeInterface,
1364 TransformValueHandleTypeInterface>(t);
1373 op->
emitError() <<
"NavigationTransformOpTrait should only be attached "
1374 "to ops that implement MemoryEffectOpInterface";
1394template <
typename OpTy>
1418template <
typename OpTy>
1437 template <
typename Range>
1443 results.reserve(llvm::size(range));
1444 for (
auto element : range) {
1445 if constexpr (std::is_convertible_v<
decltype(*std::begin(range)),
1447 results.push_back(
static_cast<Operation *
>(element));
1448 }
else if constexpr (std::is_convertible_v<
decltype(*std::begin(range)),
1450 results.push_back(element.template
get<Value>());
1452 results.push_back(
static_cast<Attribute>(element));
1469 auto begin() {
return results.begin(); }
1470 auto end() {
return results.end(); }
1471 auto begin()
const {
return results.begin(); }
1472 auto end()
const {
return results.end(); }
1475 size_t size()
const {
return results.size(); }
1480 return results[
index];
1494 const ApplyToEachResultList &partialResult);
1501 TransformResults &transformResults,
1515template <
typename TransformOpTy,
typename Range>
1519 using OpTy =
typename llvm::function_traits<
1520 decltype(&TransformOpTy::applyToOne)>::template arg_t<1>;
1521 static_assert(std::is_convertible<OpTy, Operation *>::value,
1522 "expected transform function to take an operation");
1526 unsigned expectedNumResults = transformOp->getNumResults();
1528 auto specificOp = dyn_cast<OpTy>(
target);
1531 diag <<
"transform applied to the wrong op kind";
1532 diag.attachNote(
target->getLoc()) <<
"when applied to this op";
1533 silenceableStack.push_back(std::move(
diag));
1538 partialResults.
reserve(expectedNumResults);
1539 Location specificOpLoc = specificOp->getLoc();
1542 transformOp.applyToOne(rewriter, specificOp, partialResults, state);
1555 results.push_back(std::move(partialResults));
1557 if (!silenceableStack.empty()) {
1559 std::move(silenceableStack));
1574template <
typename OpTy>
1588 llvm::to_vector(targets)))) {
1596 if (std::empty(targets)) {
1600 if (isa<TransformParamTypeInterface>(r.getType()))
1601 transformResults.
setParams(r, emptyParams);
1602 else if (isa<TransformValueHandleTypeInterface>(r.getType()))
1605 transformResults.
set(r, emptyPayload);
1614 cast<OpTy>(this->
getOperation()), rewriter, targets, results, state);
1617 if (
result.isDefiniteFailure())
1630template <
typename OpTy>
1633 static_assert(OpTy::template hasTrait<OpTrait::OneOperand>(),
1634 "expected single-operand op");
1636 return op->
emitError() <<
"TransformEachOpTrait should only be attached to "
1637 "ops that implement TransformOpInterface";
1645struct PointerLikeTypeTraits<
mlir::transform::NormalFormAttrInterface>
1646 :
public PointerLikeTypeTraits<mlir::Attribute> {
1647 static inline mlir::transform::NormalFormAttrInterface
1649 return cast<mlir::transform::NormalFormAttrInterface>(
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
static Attribute getFromOpaquePointer(const void *ptr)
Construct an attribute from the opaque pointer representation.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
MutableArrayRef< BlockArgument > BlockArgListType
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
void takeDiagnostics(SmallVectorImpl< Diagnostic > &diags)
Take the diagnostics and silence.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Helper class for implementing traits.
Operation * getOperation()
T::Concept * getInterface() const
Returns an instance of the concept object for the given interface if it was registered to this operat...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
A 2D array where each row may have different length.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operatio...
Operation * getParentOp()
Return the parent operation this region is attached to.
This class implements the result iterators for the Operation class.
RewriterBase(MLIRContext *ctx, OpBuilder::Listener *listener=nullptr)
Initialize the builder.
This base class is used for derived effects that are non-parametric.
static TypeID get()
Construct a type info object for the given type T.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Region * getParentRegion()
Return the Region in which this Value is defined.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...