35 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
40 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
45 void transform::ApplySCFStructuralConversionPatternsOp::
46 populateConversionTargetRules(
const TypeConverter &typeConverter,
60 auto payload = state.getPayloadOps(getTarget());
61 if (!llvm::hasSingleElement(payload))
62 return emitSilenceableError() <<
"expected a single payload op";
64 auto target = dyn_cast<scf::ForallOp>(*payload.begin());
67 emitSilenceableError() <<
"expected the payload to be scf.forall";
68 diag.attachNote((*payload.begin())->getLoc()) <<
"payload op";
74 if (!target.getOutputs().empty()) {
75 return emitSilenceableError()
76 <<
"unsupported shared outputs (didn't bufferize?)";
83 if (getNumResults() != lbs.size()) {
85 emitSilenceableError()
86 <<
"op expects as many results (" << getNumResults()
87 <<
") as payload has induction variables (" << lbs.size() <<
")";
88 diag.attachNote(target.getLoc()) <<
"payload op";
92 auto loc = target.getLoc();
94 for (
auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
98 auto loop = rewriter.
create<scf::ForOp>(
99 loc, lbValue, ubValue, stepValue,
ValueRange(),
101 ivs.push_back(loop.getInductionVar());
103 rewriter.
create<scf::YieldOp>(loc);
106 rewriter.
eraseOp(target.getBody()->getTerminator());
112 results.
set(cast<OpResult>(getTransformed()[i]),
113 {iv.getParentBlock()->getParentOp()});
131 scf::ExecuteRegionOp executeRegionOp =
138 assert(clonedRegion.
empty() &&
"expected empty region");
144 return executeRegionOp;
154 for (
Operation *target : state.getPayloadOps(getTarget())) {
155 Location location = target->getLoc();
160 <<
"failed to outline";
161 diag.attachNote(target->getLoc()) <<
"target op";
166 rewriter, location, exec.getRegion(), getFuncName(), &call);
169 return emitDefaultDefiniteFailure(target);
173 symbolTables.try_emplace(symbolTableOp, symbolTableOp)
175 symbolTable.
insert(*outlined);
178 functions.push_back(*outlined);
179 calls.push_back(call);
181 results.
set(cast<OpResult>(getFunction()), functions);
182 results.
set(cast<OpResult>(getCall()), calls);
218 std::vector<std::pair<Operation *, unsigned>> &schedule,
219 unsigned iterationInterval,
unsigned readLatency) {
220 auto getLatency = [&](
Operation *op) ->
unsigned {
221 if (isa<vector::TransferReadOp>(op))
227 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
228 for (
Operation &op : forOp.getBody()->getOperations()) {
229 if (isa<scf::YieldOp>(op))
231 unsigned earlyCycle = 0;
233 Operation *def = operand.getDefiningOp();
236 earlyCycle =
std::max(earlyCycle, opCycles[def] + getLatency(def));
238 opCycles[&op] = earlyCycle;
239 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
241 for (
const auto &it : wrappedSchedule) {
243 unsigned cycle = opCycles[op];
244 schedule.emplace_back(op, cycle / iterationInterval);
256 [
this](scf::ForOp forOp,
257 std::vector<std::pair<Operation *, unsigned>> &schedule)
mutable {
269 return emitDefaultSilenceableFailure(target);
280 (void)target.promoteIfSingleIteration(rewriter);
284 void transform::LoopPromoteIfOneIterationOp::getEffects(
300 if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
302 else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
307 <<
"failed to unroll";
323 if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
325 else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
331 <<
"failed to coalesce";
344 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
359 getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
360 if (!llvm::hasSingleElement(region)) {
362 <<
"requires an scf.if op with a single-block "
363 << ((getTakeElseBranch()) ?
"`else`" :
"`then`") <<
" region";
369 void transform::TakeAssumedBranchOp::getEffects(
388 if (target == source)
390 <<
"target and source need to be different loops";
395 <<
"target and source are not in the same block";
405 <<
"user of results of target should be properly dominated by "
415 Operation *operandOp = operand.getDefiningOp();
426 <<
"operands of target should be properly dominated by source";
433 if (!domInfo.properlyDominates(operand->getOwner(), source,
436 failedValue = operand;
442 <<
"values used inside regions of target should be properly "
443 "dominated by source";
457 auto targetOp = dyn_cast<scf::ForallOp>(target);
458 auto sourceOp = dyn_cast<scf::ForallOp>(source);
459 if (!targetOp || !sourceOp)
462 return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
463 targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
464 targetOp.getMixedStep() == sourceOp.getMixedStep() &&
465 targetOp.getMapping() == sourceOp.getMapping();
472 auto targetOp = dyn_cast<scf::ForallOp>(target);
473 auto sourceOp = dyn_cast<scf::ForallOp>(source);
474 if (!targetOp || !sourceOp)
483 auto targetOps = state.getPayloadOps(getTarget());
484 auto sourceOps = state.getPayloadOps(getSource());
486 if (!llvm::hasSingleElement(targetOps) ||
487 !llvm::hasSingleElement(sourceOps)) {
489 <<
"requires exactly one target handle (got "
490 << llvm::range_size(targetOps) <<
") and exactly one "
491 <<
"source handle (got " << llvm::range_size(sourceOps) <<
")";
499 if (!
diag.succeeded())
505 <<
"operations cannot be fused";
509 assert(fusedLoop &&
"failed to fuse operations");
511 results.
set(cast<OpResult>(getFusedLoop()), {fusedLoop});
520 class SCFTransformDialectExtension
522 SCFTransformDialectExtension> {
527 declareGeneratedDialect<affine::AffineDialect>();
528 declareGeneratedDialect<func::FuncDialect>();
530 registerTransformOps<
532 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
538 #define GET_OP_CLASSES
539 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
This class describes a specific conversion target.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class provides support for representing a failure result, or a valid value of type T.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op)
Walk either an scf.for or an affine.for to find a band to coalesce.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options to dictate how loops should be pipelined.