35 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
40 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
45 void transform::ApplySCFStructuralConversionPatternsOp::
46 populateConversionTargetRules(
const TypeConverter &typeConverter,
60 auto payload = state.getPayloadOps(getTarget());
61 if (!llvm::hasSingleElement(payload))
62 return emitSilenceableError() <<
"expected a single payload op";
64 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
67 emitSilenceableError() <<
"expected the payload to be scf.forall";
68 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
74 if (!target.getOutputs().empty()) {
75 return emitSilenceableError()
76 <<
"unsupported shared outputs (didn't bufferize?)";
83 if (getNumResults() != lbs.size()) {
85 emitSilenceableError()
86 <<
"op expects as many results (" << getNumResults()
87 <<
") as payload has induction variables (" << lbs.size() <<
")";
88 diag.attachNote(target.getLoc()) <<
"payload op";
92 auto loc = target.getLoc();
94 for (
auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
98 auto loop = rewriter.
create<scf::ForOp>(
99 loc, lbValue, ubValue, stepValue,
ValueRange(),
101 ivs.push_back(loop.getInductionVar());
103 rewriter.
create<scf::YieldOp>(loc);
106 rewriter.
eraseOp(target.getBody()->getTerminator());
112 results.
set(cast<OpResult>(getTransformed()[i]),
113 {iv.getParentBlock()->getParentOp()});
131 scf::ExecuteRegionOp executeRegionOp =
138 assert(clonedRegion.
empty() &&
"expected empty region");
144 return executeRegionOp;
154 for (
Operation *target : state.getPayloadOps(getTarget())) {
155 Location location = target->getLoc();
160 <<
"failed to outline";
161 diag.attachNote(target->getLoc()) <<
"target op";
166 rewriter, location, exec.getRegion(), getFuncName(), &call);
169 return emitDefaultDefiniteFailure(target);
173 symbolTables.try_emplace(symbolTableOp, symbolTableOp)
175 symbolTable.
insert(*outlined);
178 functions.push_back(*outlined);
179 calls.push_back(call);
181 results.
set(cast<OpResult>(getFunction()), functions);
182 results.
set(cast<OpResult>(getCall()), calls);
196 if (getPeelFront()) {
201 emitSilenceableError() <<
"failed to peel the first iteration";
209 <<
"failed to peel the last iteration";
229 std::vector<std::pair<Operation *, unsigned>> &schedule,
230 unsigned iterationInterval,
unsigned readLatency) {
231 auto getLatency = [&](
Operation *op) ->
unsigned {
232 if (isa<vector::TransferReadOp>(op))
238 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
239 for (
Operation &op : forOp.getBody()->getOperations()) {
240 if (isa<scf::YieldOp>(op))
242 unsigned earlyCycle = 0;
244 Operation *def = operand.getDefiningOp();
247 earlyCycle =
std::max(earlyCycle, opCycles[def] + getLatency(def));
249 opCycles[&op] = earlyCycle;
250 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
252 for (
const auto &it : wrappedSchedule) {
254 unsigned cycle = opCycles[op];
255 schedule.emplace_back(op, cycle / iterationInterval);
267 [
this](scf::ForOp forOp,
268 std::vector<std::pair<Operation *, unsigned>> &schedule)
mutable {
280 return emitDefaultSilenceableFailure(target);
291 (void)target.promoteIfSingleIteration(rewriter);
295 void transform::LoopPromoteIfOneIterationOp::getEffects(
311 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
313 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
318 <<
"failed to unroll";
334 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
336 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
342 <<
"failed to coalesce";
355 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
370 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
371 if (!llvm::hasSingleElement(region)) {
373 <<
"requires an scf.if op with a single-block "
374 << ((getTakeElseBranch()) ?
"`else`" :
"`then`") <<
" region";
380 void transform::TakeAssumedBranchOp::getEffects(
399 if (target == source)
401 <<
"target and source need to be different loops";
406 <<
"target and source are not in the same block";
416 <<
"user of results of target should be properly dominated by "
426 Operation *operandOp = operand.getDefiningOp();
436 <<
"operands of target should be properly dominated by source";
443 Operation *operandOp = operand->get().getDefiningOp();
444 if (operandOp && !domInfo.properlyDominates(operandOp, source,
449 failedValue = operand;
455 <<
"values used inside regions of target should be properly "
456 "dominated by source";
469 auto targetOp = dyn_cast<scf::ForallOp>(target);
470 auto sourceOp = dyn_cast<scf::ForallOp>(source);
471 if (!targetOp || !sourceOp)
474 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
475 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
476 targetOp.getMixedStep() == sourceOp.getMixedStep() &&
477 targetOp.getMapping() == sourceOp.getMapping();
487 auto targetOp = dyn_cast<scf::ForOp>(target);
488 auto sourceOp = dyn_cast<scf::ForOp>(source);
489 if (!targetOp || !sourceOp)
492 return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
493 targetOp.getUpperBound() == sourceOp.getUpperBound() &&
494 targetOp.getStep() == sourceOp.getStep();
501 auto targetOps = state.getPayloadOps(getTarget());
502 auto sourceOps = state.getPayloadOps(getSource());
504 if (!llvm::hasSingleElement(targetOps) ||
505 !llvm::hasSingleElement(sourceOps)) {
507 <<
"requires exactly one target handle (got "
508 << llvm::range_size(targetOps) <<
") and exactly one "
509 <<
"source handle (got " << llvm::range_size(sourceOps) <<
")";
517 if (!
diag.succeeded())
524 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
527 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
530 <<
"operations cannot be fused";
532 assert(fusedLoop &&
"failed to fuse operations");
534 results.
set(cast<OpResult>(getFusedLoop()), {fusedLoop});
543 class SCFTransformDialectExtension
545 SCFTransformDialectExtension> {
550 declareGeneratedDialect<affine::AffineDialect>();
551 declareGeneratedDialect<func::FuncDialect>();
553 registerTransformOps<
555 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
561 #define GET_OP_CLASSES
562 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class describes a specific conversion target.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class provides support for representing a failure result, or a valid value of type T.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op)
Walk either an scf.for or an affine.for to find a band to coalesce.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Peel the first iteration out of the scf.for loop.
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options to dictate how loops should be pipelined.