34 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
39 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
44 void transform::ApplySCFStructuralConversionPatternsOp::
45 populateConversionTargetRules(
const TypeConverter &typeConverter,
51 void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
64 auto payload = state.getPayloadOps(getTarget());
65 if (!llvm::hasSingleElement(payload))
66 return emitSilenceableError() <<
"expected a single payload op";
68 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
71 emitSilenceableError() <<
"expected the payload to be scf.forall";
72 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
76 if (!target.getOutputs().empty()) {
77 return emitSilenceableError()
78 <<
"unsupported shared outputs (didn't bufferize?)";
83 if (getNumResults() != lbs.size()) {
85 emitSilenceableError()
86 <<
"op expects as many results (" << getNumResults()
87 <<
") as payload has induction variables (" << lbs.size() <<
")";
88 diag.attachNote(target.getLoc()) <<
"payload op";
95 <<
"failed to convert forall into for";
100 results.
set(cast<OpResult>(getTransformed()[i]), {res});
113 auto payload = state.getPayloadOps(getTarget());
114 if (!llvm::hasSingleElement(payload))
115 return emitSilenceableError() <<
"expected a single payload op";
117 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
120 emitSilenceableError() <<
"expected the payload to be scf.forall";
121 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
125 if (!target.getOutputs().empty()) {
126 return emitSilenceableError()
127 <<
"unsupported shared outputs (didn't bufferize?)";
130 if (getNumResults() != 1) {
132 <<
"op expects one result, given "
134 diag.attachNote(target.getLoc()) <<
"payload op";
138 scf::ParallelOp opResult;
141 emitSilenceableError() <<
"failed to convert forall into parallel";
145 results.
set(cast<OpResult>(getTransformed()[0]), {opResult});
156 auto payload = state.getPayloadOps(getTarget());
157 if (!llvm::hasSingleElement(payload))
158 return emitSilenceableError() <<
"expected a single payload op";
160 auto target = dyn_cast<scf::ParallelOp>(*payload.begin());
163 emitSilenceableError() <<
"expected the payload to be scf.parallel";
164 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
168 if (getNumResults() != 1) {
170 <<
"op expects one result, given "
172 diag.attachNote(target.getLoc()) <<
"payload op";
176 FailureOr<scf::LoopNest> loopNest =
180 emitSilenceableError() <<
"failed to convert parallel into nested fors";
184 results.
set(cast<OpResult>(getTransformed()[0]), {loopNest->loops.front()});
201 scf::ExecuteRegionOp executeRegionOp =
208 assert(clonedRegion.
empty() &&
"expected empty region");
214 return executeRegionOp;
224 for (
Operation *target : state.getPayloadOps(getTarget())) {
225 Location location = target->getLoc();
230 <<
"failed to outline";
231 diag.attachNote(target->getLoc()) <<
"target op";
236 rewriter, location, exec.getRegion(), getFuncName(), &call);
239 return emitDefaultDefiniteFailure(target);
243 symbolTables.try_emplace(symbolTableOp, symbolTableOp)
245 symbolTable.
insert(*outlined);
248 functions.push_back(*outlined);
249 calls.push_back(call);
251 results.
set(cast<OpResult>(getFunction()), functions);
252 results.
set(cast<OpResult>(getCall()), calls);
266 if (getPeelFront()) {
267 LogicalResult status =
271 emitSilenceableError() <<
"failed to peel the first iteration";
275 LogicalResult status =
279 <<
"failed to peel the last iteration";
299 std::vector<std::pair<Operation *, unsigned>> &schedule,
300 unsigned iterationInterval,
unsigned readLatency) {
301 auto getLatency = [&](
Operation *op) ->
unsigned {
302 if (isa<vector::TransferReadOp>(op))
307 std::optional<int64_t> ubConstant =
309 std::optional<int64_t> lbConstant =
312 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
313 for (
Operation &op : forOp.getBody()->getOperations()) {
314 if (isa<scf::YieldOp>(op))
316 unsigned earlyCycle = 0;
317 for (
Value operand : op.getOperands()) {
318 Operation *def = operand.getDefiningOp();
321 if (ubConstant && lbConstant) {
322 unsigned ubInt = ubConstant.value();
323 unsigned lbInt = lbConstant.value();
324 auto minLatency =
std::min(ubInt - lbInt - 1, getLatency(def));
325 earlyCycle =
std::max(earlyCycle, opCycles[def] + minLatency);
327 earlyCycle =
std::max(earlyCycle, opCycles[def] + getLatency(def));
330 opCycles[&op] = earlyCycle;
331 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
333 for (
const auto &it : wrappedSchedule) {
335 unsigned cycle = opCycles[op];
336 schedule.emplace_back(op, cycle / iterationInterval);
348 [
this](scf::ForOp forOp,
349 std::vector<std::pair<Operation *, unsigned>> &schedule)
mutable {
355 FailureOr<scf::ForOp> patternResult =
357 if (succeeded(patternResult)) {
361 return emitDefaultSilenceableFailure(target);
372 (void)target.promoteIfSingleIteration(rewriter);
376 void transform::LoopPromoteIfOneIterationOp::getEffects(
391 LogicalResult result(failure());
392 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
394 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
397 return emitSilenceableError()
398 <<
"failed to unroll, incorrect type of payload";
401 return emitSilenceableError() <<
"failed to unroll";
414 LogicalResult result(failure());
415 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
417 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
420 return emitSilenceableError()
421 <<
"failed to unroll and jam, incorrect type of payload";
424 return emitSilenceableError() <<
"failed to unroll and jam";
438 LogicalResult result(failure());
439 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
441 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
447 <<
"failed to coalesce";
460 assert(region.
hasOneBlock() &&
"expected single-block region");
475 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
478 <<
"requires an scf.if op with a single-block "
479 << ((getTakeElseBranch()) ?
"`else`" :
"`then`") <<
" region";
485 void transform::TakeAssumedBranchOp::getEffects(
504 if (target == source)
506 <<
"target and source need to be different loops";
511 <<
"target and source are not in the same block";
521 <<
"user of results of target should be properly dominated by "
531 Operation *operandOp = operand.getDefiningOp();
541 <<
"operands of target should be properly dominated by source";
548 Operation *operandOp = operand->get().getDefiningOp();
549 if (operandOp && !domInfo.properlyDominates(operandOp, source,
554 failedValue = operand;
560 <<
"values used inside regions of target should be properly "
561 "dominated by source";
574 auto targetOp = dyn_cast<scf::ForallOp>(target);
575 auto sourceOp = dyn_cast<scf::ForallOp>(source);
576 if (!targetOp || !sourceOp)
579 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
580 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
581 targetOp.getMixedStep() == sourceOp.getMixedStep() &&
582 targetOp.getMapping() == sourceOp.getMapping();
592 auto targetOp = dyn_cast<scf::ForOp>(target);
593 auto sourceOp = dyn_cast<scf::ForOp>(source);
594 if (!targetOp || !sourceOp)
597 return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
598 targetOp.getUpperBound() == sourceOp.getUpperBound() &&
599 targetOp.getStep() == sourceOp.getStep();
606 auto targetOps = state.getPayloadOps(getTarget());
607 auto sourceOps = state.getPayloadOps(getSource());
609 if (!llvm::hasSingleElement(targetOps) ||
610 !llvm::hasSingleElement(sourceOps)) {
612 <<
"requires exactly one target handle (got "
613 << llvm::range_size(targetOps) <<
") and exactly one "
614 <<
"source handle (got " << llvm::range_size(sourceOps) <<
")";
622 if (!
diag.succeeded())
629 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
632 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
635 <<
"operations cannot be fused";
638 assert(fusedLoop &&
"failed to fuse operations");
640 results.
set(cast<OpResult>(getFusedLoop()), {fusedLoop});
649 class SCFTransformDialectExtension
651 SCFTransformDialectExtension> {
658 declareGeneratedDialect<affine::AffineDialect>();
659 declareGeneratedDialect<func::FuncDialect>();
661 registerTransformOps<
663 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
669 #define GET_OP_CLASSES
670 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class describes a specific conversion target.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
LogicalResult loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor)
Unrolls and jams this loop by the specified factor.
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op)
Walk an affine.for to find a band to coalesce.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerTransformDialectExtension(DialectRegistry &registry)
FailureOr< scf::LoopNest > parallelForToNestedFors(RewriterBase &rewriter, ParallelOp parallelOp)
Try converting scf.parallel into a set of nested scf.for loops.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, SmallVectorImpl< Operation * > *results=nullptr)
Try converting scf.forall into a set of nested scf.for loops.
LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Peel the first iteration out of the scf.for loop.
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Options to dictate how loops should be pipelined.