34 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
39 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
44 void transform::ApplySCFStructuralConversionPatternsOp::
45 populateConversionTargetRules(
const TypeConverter &typeConverter,
51 void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
64 auto payload = state.getPayloadOps(getTarget());
65 if (!llvm::hasSingleElement(payload))
66 return emitSilenceableError() <<
"expected a single payload op";
68 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
71 emitSilenceableError() <<
"expected the payload to be scf.forall";
72 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
76 if (!target.getOutputs().empty()) {
77 return emitSilenceableError()
78 <<
"unsupported shared outputs (didn't bufferize?)";
83 if (getNumResults() != lbs.size()) {
85 emitSilenceableError()
86 <<
"op expects as many results (" << getNumResults()
87 <<
") as payload has induction variables (" << lbs.size() <<
")";
88 diag.attachNote(target.getLoc()) <<
"payload op";
95 <<
"failed to convert forall into for";
100 results.
set(cast<OpResult>(getTransformed()[i]), {res});
113 auto payload = state.getPayloadOps(getTarget());
114 if (!llvm::hasSingleElement(payload))
115 return emitSilenceableError() <<
"expected a single payload op";
117 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
120 emitSilenceableError() <<
"expected the payload to be scf.forall";
121 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
125 if (!target.getOutputs().empty()) {
126 return emitSilenceableError()
127 <<
"unsupported shared outputs (didn't bufferize?)";
130 if (getNumResults() != 1) {
132 <<
"op expects one result, given "
134 diag.attachNote(target.getLoc()) <<
"payload op";
138 scf::ParallelOp opResult;
141 emitSilenceableError() <<
"failed to convert forall into parallel";
145 results.
set(cast<OpResult>(getTransformed()[0]), {opResult});
162 scf::ExecuteRegionOp executeRegionOp =
169 assert(clonedRegion.
empty() &&
"expected empty region");
175 return executeRegionOp;
185 for (
Operation *target : state.getPayloadOps(getTarget())) {
186 Location location = target->getLoc();
191 <<
"failed to outline";
192 diag.attachNote(target->getLoc()) <<
"target op";
197 rewriter, location, exec.getRegion(), getFuncName(), &call);
199 if (failed(outlined))
200 return emitDefaultDefiniteFailure(target);
204 symbolTables.try_emplace(symbolTableOp, symbolTableOp)
206 symbolTable.
insert(*outlined);
209 functions.push_back(*outlined);
210 calls.push_back(call);
212 results.
set(cast<OpResult>(getFunction()), functions);
213 results.
set(cast<OpResult>(getCall()), calls);
227 if (getPeelFront()) {
228 LogicalResult status =
230 if (failed(status)) {
232 emitSilenceableError() <<
"failed to peel the first iteration";
236 LogicalResult status =
238 if (failed(status)) {
240 <<
"failed to peel the last iteration";
260 std::vector<std::pair<Operation *, unsigned>> &schedule,
261 unsigned iterationInterval,
unsigned readLatency) {
262 auto getLatency = [&](
Operation *op) ->
unsigned {
263 if (isa<vector::TransferReadOp>(op))
268 std::optional<int64_t> ubConstant =
270 std::optional<int64_t> lbConstant =
273 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
274 for (
Operation &op : forOp.getBody()->getOperations()) {
275 if (isa<scf::YieldOp>(op))
277 unsigned earlyCycle = 0;
278 for (
Value operand : op.getOperands()) {
279 Operation *def = operand.getDefiningOp();
282 if (ubConstant && lbConstant) {
283 unsigned ubInt = ubConstant.value();
284 unsigned lbInt = lbConstant.value();
285 auto minLatency =
std::min(ubInt - lbInt - 1, getLatency(def));
286 earlyCycle =
std::max(earlyCycle, opCycles[def] + minLatency);
288 earlyCycle =
std::max(earlyCycle, opCycles[def] + getLatency(def));
291 opCycles[&op] = earlyCycle;
292 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
294 for (
const auto &it : wrappedSchedule) {
296 unsigned cycle = opCycles[op];
297 schedule.emplace_back(op, cycle / iterationInterval);
309 [
this](scf::ForOp forOp,
310 std::vector<std::pair<Operation *, unsigned>> &schedule)
mutable {
316 FailureOr<scf::ForOp> patternResult =
318 if (succeeded(patternResult)) {
322 return emitDefaultSilenceableFailure(target);
333 (void)target.promoteIfSingleIteration(rewriter);
337 void transform::LoopPromoteIfOneIterationOp::getEffects(
352 LogicalResult result(failure());
353 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
355 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
358 return emitSilenceableError()
359 <<
"failed to unroll, incorrect type of payload";
362 return emitSilenceableError() <<
"failed to unroll";
375 LogicalResult result(failure());
376 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
378 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
381 return emitSilenceableError()
382 <<
"failed to unroll and jam, incorrect type of payload";
385 return emitSilenceableError() <<
"failed to unroll and jam";
399 LogicalResult result(failure());
400 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
402 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
406 if (failed(result)) {
408 <<
"failed to coalesce";
421 assert(region.
hasOneBlock() &&
"expected single-block region");
436 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
439 <<
"requires an scf.if op with a single-block "
440 << ((getTakeElseBranch()) ?
"`else`" :
"`then`") <<
" region";
446 void transform::TakeAssumedBranchOp::getEffects(
465 if (target == source)
467 <<
"target and source need to be different loops";
472 <<
"target and source are not in the same block";
482 <<
"user of results of target should be properly dominated by "
492 Operation *operandOp = operand.getDefiningOp();
502 <<
"operands of target should be properly dominated by source";
509 Operation *operandOp = operand->get().getDefiningOp();
510 if (operandOp && !domInfo.properlyDominates(operandOp, source,
515 failedValue = operand;
521 <<
"values used inside regions of target should be properly "
522 "dominated by source";
535 auto targetOp = dyn_cast<scf::ForallOp>(target);
536 auto sourceOp = dyn_cast<scf::ForallOp>(source);
537 if (!targetOp || !sourceOp)
540 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
541 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
542 targetOp.getMixedStep() == sourceOp.getMixedStep() &&
543 targetOp.getMapping() == sourceOp.getMapping();
553 auto targetOp = dyn_cast<scf::ForOp>(target);
554 auto sourceOp = dyn_cast<scf::ForOp>(source);
555 if (!targetOp || !sourceOp)
558 return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
559 targetOp.getUpperBound() == sourceOp.getUpperBound() &&
560 targetOp.getStep() == sourceOp.getStep();
567 auto targetOps = state.getPayloadOps(getTarget());
568 auto sourceOps = state.getPayloadOps(getSource());
570 if (!llvm::hasSingleElement(targetOps) ||
571 !llvm::hasSingleElement(sourceOps)) {
573 <<
"requires exactly one target handle (got "
574 << llvm::range_size(targetOps) <<
") and exactly one "
575 <<
"source handle (got " << llvm::range_size(sourceOps) <<
")";
583 if (!
diag.succeeded())
590 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
593 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
596 <<
"operations cannot be fused";
599 assert(fusedLoop &&
"failed to fuse operations");
601 results.
set(cast<OpResult>(getFusedLoop()), {fusedLoop});
610 class SCFTransformDialectExtension
612 SCFTransformDialectExtension> {
619 declareGeneratedDialect<affine::AffineDialect>();
620 declareGeneratedDialect<func::FuncDialect>();
622 registerTransformOps<
624 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
630 #define GET_OP_CLASSES
631 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class describes a specific conversion target.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
LogicalResult loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor)
Unrolls and jams this loop by the specified factor.
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op)
Walk an affine.for to find a band to coalesce.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, SmallVectorImpl< Operation * > *results=nullptr)
Try converting scf.forall into a set of nested scf.for loops.
LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Peel the first iteration out of the scf.for loop.
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Options to dictate how loops should be pipelined.