37 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
42 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
47 void transform::ApplySCFStructuralConversionPatternsOp::
48 populateConversionTargetRules(
const TypeConverter &typeConverter,
54 void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
67 auto payload = state.getPayloadOps(getTarget());
68 if (!llvm::hasSingleElement(payload))
69 return emitSilenceableError() <<
"expected a single payload op";
71 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
74 emitSilenceableError() <<
"expected the payload to be scf.forall";
75 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
79 if (!target.getOutputs().empty()) {
80 return emitSilenceableError()
81 <<
"unsupported shared outputs (didn't bufferize?)";
86 if (getNumResults() != lbs.size()) {
88 emitSilenceableError()
89 <<
"op expects as many results (" << getNumResults()
90 <<
") as payload has induction variables (" << lbs.size() <<
")";
91 diag.attachNote(target.getLoc()) <<
"payload op";
98 <<
"failed to convert forall into for";
103 results.
set(cast<OpResult>(getTransformed()[i]), {res});
116 auto payload = state.getPayloadOps(getTarget());
117 if (!llvm::hasSingleElement(payload))
118 return emitSilenceableError() <<
"expected a single payload op";
120 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
123 emitSilenceableError() <<
"expected the payload to be scf.forall";
124 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
128 if (!target.getOutputs().empty()) {
129 return emitSilenceableError()
130 <<
"unsupported shared outputs (didn't bufferize?)";
133 if (getNumResults() != 1) {
135 <<
"op expects one result, given "
137 diag.attachNote(target.getLoc()) <<
"payload op";
141 scf::ParallelOp opResult;
144 emitSilenceableError() <<
"failed to convert forall into parallel";
148 results.
set(cast<OpResult>(getTransformed()[0]), {opResult});
165 scf::ExecuteRegionOp executeRegionOp =
172 assert(clonedRegion.
empty() &&
"expected empty region");
178 return executeRegionOp;
188 for (
Operation *target : state.getPayloadOps(getTarget())) {
189 Location location = target->getLoc();
194 <<
"failed to outline";
195 diag.attachNote(target->getLoc()) <<
"target op";
200 rewriter, location, exec.getRegion(), getFuncName(), &call);
202 if (failed(outlined))
203 return emitDefaultDefiniteFailure(target);
207 symbolTables.try_emplace(symbolTableOp, symbolTableOp)
209 symbolTable.
insert(*outlined);
212 functions.push_back(*outlined);
213 calls.push_back(call);
215 results.
set(cast<OpResult>(getFunction()), functions);
216 results.
set(cast<OpResult>(getCall()), calls);
230 if (getPeelFront()) {
231 LogicalResult status =
233 if (failed(status)) {
235 emitSilenceableError() <<
"failed to peel the first iteration";
239 LogicalResult status =
241 if (failed(status)) {
243 <<
"failed to peel the last iteration";
263 std::vector<std::pair<Operation *, unsigned>> &schedule,
264 unsigned iterationInterval,
unsigned readLatency) {
265 auto getLatency = [&](
Operation *op) ->
unsigned {
266 if (isa<vector::TransferReadOp>(op))
271 std::optional<int64_t> ubConstant =
273 std::optional<int64_t> lbConstant =
276 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
277 for (
Operation &op : forOp.getBody()->getOperations()) {
278 if (isa<scf::YieldOp>(op))
280 unsigned earlyCycle = 0;
281 for (
Value operand : op.getOperands()) {
282 Operation *def = operand.getDefiningOp();
285 if (ubConstant && lbConstant) {
286 unsigned ubInt = ubConstant.value();
287 unsigned lbInt = lbConstant.value();
288 auto minLatency =
std::min(ubInt - lbInt - 1, getLatency(def));
289 earlyCycle =
std::max(earlyCycle, opCycles[def] + minLatency);
291 earlyCycle =
std::max(earlyCycle, opCycles[def] + getLatency(def));
294 opCycles[&op] = earlyCycle;
295 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
297 for (
const auto &it : wrappedSchedule) {
299 unsigned cycle = opCycles[op];
300 schedule.emplace_back(op, cycle / iterationInterval);
312 [
this](scf::ForOp forOp,
313 std::vector<std::pair<Operation *, unsigned>> &schedule)
mutable {
319 FailureOr<scf::ForOp> patternResult =
321 if (succeeded(patternResult)) {
325 return emitDefaultSilenceableFailure(target);
336 (void)target.promoteIfSingleIteration(rewriter);
340 void transform::LoopPromoteIfOneIterationOp::getEffects(
355 LogicalResult result(failure());
356 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
358 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
361 return emitSilenceableError()
362 <<
"failed to unroll, incorrect type of payload";
365 return emitSilenceableError() <<
"failed to unroll";
378 LogicalResult result(failure());
379 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
381 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
384 return emitSilenceableError()
385 <<
"failed to unroll and jam, incorrect type of payload";
388 return emitSilenceableError() <<
"failed to unroll and jam";
402 LogicalResult result(failure());
403 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
405 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
409 if (failed(result)) {
411 <<
"failed to coalesce";
424 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
439 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
440 if (!llvm::hasSingleElement(region)) {
442 <<
"requires an scf.if op with a single-block "
443 << ((getTakeElseBranch()) ?
"`else`" :
"`then`") <<
" region";
449 void transform::TakeAssumedBranchOp::getEffects(
468 if (target == source)
470 <<
"target and source need to be different loops";
475 <<
"target and source are not in the same block";
485 <<
"user of results of target should be properly dominated by "
495 Operation *operandOp = operand.getDefiningOp();
505 <<
"operands of target should be properly dominated by source";
512 Operation *operandOp = operand->get().getDefiningOp();
513 if (operandOp && !domInfo.properlyDominates(operandOp, source,
518 failedValue = operand;
524 <<
"values used inside regions of target should be properly "
525 "dominated by source";
538 auto targetOp = dyn_cast<scf::ForallOp>(target);
539 auto sourceOp = dyn_cast<scf::ForallOp>(source);
540 if (!targetOp || !sourceOp)
543 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
544 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
545 targetOp.getMixedStep() == sourceOp.getMixedStep() &&
546 targetOp.getMapping() == sourceOp.getMapping();
556 auto targetOp = dyn_cast<scf::ForOp>(target);
557 auto sourceOp = dyn_cast<scf::ForOp>(source);
558 if (!targetOp || !sourceOp)
561 return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
562 targetOp.getUpperBound() == sourceOp.getUpperBound() &&
563 targetOp.getStep() == sourceOp.getStep();
570 auto targetOps = state.getPayloadOps(getTarget());
571 auto sourceOps = state.getPayloadOps(getSource());
573 if (!llvm::hasSingleElement(targetOps) ||
574 !llvm::hasSingleElement(sourceOps)) {
576 <<
"requires exactly one target handle (got "
577 << llvm::range_size(targetOps) <<
") and exactly one "
578 <<
"source handle (got " << llvm::range_size(sourceOps) <<
")";
586 if (!
diag.succeeded())
593 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
596 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
599 <<
"operations cannot be fused";
601 assert(fusedLoop &&
"failed to fuse operations");
603 results.
set(cast<OpResult>(getFusedLoop()), {fusedLoop});
612 class SCFTransformDialectExtension
614 SCFTransformDialectExtension> {
621 declareGeneratedDialect<affine::AffineDialect>();
622 declareGeneratedDialect<func::FuncDialect>();
624 registerTransformOps<
626 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
632 #define GET_OP_CLASSES
633 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class describes a specific conversion target.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
LogicalResult loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor)
Unrolls and jams this loop by the specified factor.
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op)
Walk an affine.for to find a band to coalesce.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, SmallVectorImpl< Operation * > *results=nullptr)
Try converting scf.forall into a set of nested scf.for loops.
LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Peel the first iteration out of the scf.for loop.
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Options to dictate how loops should be pipelined.