13 #include "llvm/Support/Debug.h" 15 #define DEBUG_TYPE "transform-dialect" 16 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 24 constexpr
const Value transform::TransformState::kTopLevelValue;
28 : topLevel(root), options(options) {
29 auto result = mappings.try_emplace(®ion);
30 assert(result.second &&
"the region scope is already present");
32 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 33 regionStack.push_back(®ion);
34 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 42 auto iter = operationMapping.find(value);
43 assert(iter != operationMapping.end() &&
"unknown handle");
44 return iter->getSecond();
48 for (
const Mappings &mapping : llvm::make_second_range(mappings)) {
49 if (
Value handle = mapping.reverse.lookup(op))
55 LogicalResult transform::TransformState::tryEmplaceReverseMapping(
57 auto insertionResult = map.reverse.insert({operation, handle});
58 if (!insertionResult.second && insertionResult.first->second != handle) {
60 <<
"operation tracked by two handles";
62 diag.
attachNote(insertionResult.first->second.getLoc()) <<
"handle";
69 transform::TransformState::setPayloadOps(
Value value,
71 assert(value != kTopLevelValue &&
72 "attempting to reset the transformation root");
80 Mappings &mappings = getMapping(value);
82 mappings.direct.insert({
value, std::move(storedTargets)}).second;
83 assert(inserted &&
"value is already associated with another list");
90 if (
failed(tryEmplaceReverseMapping(mappings, op, value)))
97 void transform::TransformState::removePayloadOps(
Value value) {
98 Mappings &mappings = getMapping(value);
99 for (
Operation *op : mappings.direct[value])
100 mappings.reverse.erase(op);
101 mappings.direct.erase(value);
106 Mappings &mappings = getMapping(value);
107 auto it = mappings.direct.find(value);
108 assert(it != mappings.direct.end() &&
"unknown handle");
111 updated.reserve(association.size());
114 mappings.reverse.erase(op);
115 if (
Operation *updatedOp = callback(op)) {
116 updated.push_back(updatedOp);
117 if (
failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
122 std::swap(association, updated);
126 void transform::TransformState::recordHandleInvalidation(
OpOperand &handle) {
128 for (
const Mappings &mapping : llvm::make_second_range(mappings)) {
129 for (
const auto &kvp : mapping.reverse) {
133 Value otherHandle = kvp.second;
134 if (invalidatedHandles.count(otherHandle))
137 for (
Operation *ancestor : potentialAncestors) {
138 if (!ancestor->isProperAncestor(op))
145 Location ancestorLoc = ancestor->getLoc();
149 invalidatedHandles[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
153 <<
"invalidated the handle to payload operations nested in the " 154 "payload operation associated with its operand #" 156 diag.
attachNote(ancestorLoc) <<
"ancestor op";
165 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
166 TransformOpInterface transform) {
167 auto memoryEffectsIface =
168 cast<MemoryEffectOpInterface>(transform.getOperation());
170 memoryEffectsIface.getEffectsOnResource(
173 for (
OpOperand &target : transform->getOpOperands()) {
175 auto it = invalidatedHandles.find(target.get());
176 if (it != invalidatedHandles.end())
177 return it->getSecond()(),
failure();
182 return isa<MemoryEffects::Free>(effect.getEffect()) &&
183 effect.getValue() == target.get();
185 if (llvm::any_of(effects, consumesTarget))
186 recordHandleInvalidation(target);
193 LLVM_DEBUG(
DBGS() <<
"applying: " << transform <<
"\n");
195 if (
failed(checkAndRecordHandleInvalidation(transform)))
198 for (
OpOperand &operand : transform->getOpOperands()) {
204 if (!seen.insert(op).second) {
206 transform.emitSilenceableError()
207 <<
"a handle passed as operand #" << operand.getOperandNumber()
208 <<
" and consumed by this operation points to a payload " 209 "operation more than once";
210 diag.
attachNote(op->getLoc()) <<
"repeated target op";
219 if (!result.succeeded())
224 auto memEffectInterface =
225 cast<MemoryEffectOpInterface>(transform.getOperation());
227 for (
OpOperand &target : transform->getOpOperands()) {
229 memEffectInterface.getEffectsOnValue(target.get(), effects);
231 return isa<transform::TransformMappingResource>(
233 isa<MemoryEffects::Free>(effect.
getEffect());
235 removePayloadOps(target.get());
239 for (
OpResult result : transform->getResults()) {
240 assert(result.getDefiningOp() == transform.getOperation() &&
241 "payload IR association for a value other than the result of the " 242 "current transform op");
243 if (
failed(setPayloadOps(result, results.get(result.getResultNumber()))))
259 return state.updatePayloadOps(state.getHandleForPayloadOp(op),
261 return current == op ? replacement : current;
269 transform::TransformResults::TransformResults(
unsigned numSegments) {
270 segments.resize(numSegments,
277 assert(position < segments.size() &&
278 "setting results for a non-existent handle");
279 assert(segments[position].data() ==
nullptr &&
"results already set");
280 unsigned start = operations.size();
281 llvm::append_range(operations, ops);
282 segments[position] = makeArrayRef(operations).drop_front(start);
286 transform::TransformResults::get(
unsigned resultNumber)
const {
287 assert(resultNumber < segments.size() &&
288 "querying results for a non-existent handle");
289 assert(segments[resultNumber].data() !=
nullptr &&
"querying unset results");
290 return segments[resultNumber];
313 assert(isa<TransformOpInterface>(op) &&
314 "should implement TransformOpInterface to have " 315 "PossibleTopLevelTransformOpTrait");
318 return op->
emitOpError() <<
"expects at least one region";
321 if (!llvm::hasNItems(*bodyRegion, 1))
322 return op->
emitOpError() <<
"expects a single-block region";
328 <<
"expects the entry block to have one argument of type " 337 <<
"expects the root operation to be provided for a nested op";
339 <<
"nested in another possible top-level op";
354 for (
Value handle : handles) {
364 template <
typename EffectTy,
typename ResourceTy = S
ideEffects::DefaultResource>
367 return isa<EffectTy>(effect.
getEffect()) &&
373 transform::TransformOpInterface transform) {
374 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
376 iface.getEffectsOnValue(handle, effects);
377 return hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
378 hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
384 for (
Value handle : handles) {
395 for (
Value handle : handles) {
416 #include "mlir/Dialect/Transform/IR/TransformInterfaces.cpp.inc" Diagnostic & attachNote(Optional< Location > loc=llvm::None)
Attaches a note to the last diagnostic.
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
static std::string diag(llvm::Value &v)
The result of a transform IR operation application.
Operation is a basic unit of execution within MLIR.
This is a value defined by a result of an operation.
EffectT * getEffect() const
Return the effect being applied.
unsigned getNumRegions()
Returns the number of regions held by this operation.
This class represents a diagnostic that is inflight and set to be reported.
Block represents an ordered list of Operations.
Value getOperand(unsigned idx)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
unsigned getNumOperands()
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
BlockArgument getArgument(unsigned i)
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext * getContext()
Return the context this operation is associated with.
static TransformMappingResource * get()
Returns a unique instance for the given effect class.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
unsigned getNumArguments()
unsigned getResultNumber() const
Returns the number of this result.
Location getLoc()
The source location the operation was defined or derived from.
IRValueT get() const
Return the current value being used by this operand.
This class represents a specific instance of an effect.
Location getLoc() const
Return the location of this value.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
static llvm::ManagedStatic< PassManagerOptions > options
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Resource * getResource() const
Return the resource that the effect applies to.
Operation * getOwner() const
Return the owner of this operand.
This class represents an operand of an operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
This class provides an abstraction over the different types of ranges over Values.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.