42 for (
unsigned i = 0, e = getNumLoops(); i < e; ++i) {
46 <<
"could not find an '" 47 << scf::ForOp::getOperationName()
49 diag.
attachNote(target->getLoc()) <<
"target op";
56 results.
set(getResult().cast<OpResult>(), parents.getArrayRef());
73 scf::ExecuteRegionOp executeRegionOp =
80 assert(clonedRegion.
empty() &&
"expected empty region");
86 return executeRegionOp;
95 Location location = target->getLoc();
97 SimpleRewriter rewriter(getContext());
101 <<
"failed to outline";
102 diag.
attachNote(target->getLoc()) <<
"target op";
107 rewriter, location, exec.getRegion(), getFuncName(), &call);
110 (
void)reportUnknownTransformError(target);
116 symbolTables.try_emplace(symbolTableOp, symbolTableOp)
118 symbolTable.
insert(*outlined);
121 transformed.push_back(*outlined);
123 results.
set(getTransformed().cast<OpResult>(), transformed);
132 transform::LoopPeelOp::applyToOne(scf::ForOp target,
144 results.push_back(
failed(status) ? target : result);
157 std::vector<std::pair<Operation *, unsigned>> &schedule,
158 unsigned iterationInterval,
unsigned readLatency) {
159 auto getLatency = [&](
Operation *op) ->
unsigned {
160 if (isa<vector::TransferReadOp>(op))
166 std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
167 for (
Operation &op : forOp.getBody()->getOperations()) {
168 if (isa<scf::YieldOp>(op))
170 unsigned earlyCycle = 0;
171 for (
Value operand : op.getOperands()) {
172 Operation *def = operand.getDefiningOp();
175 earlyCycle =
std::max(earlyCycle, opCycles[def] + getLatency(def));
177 opCycles[&op] = earlyCycle;
178 wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
180 for (
const auto &it : wrappedSchedule) {
182 unsigned cycle = opCycles[op];
183 schedule.emplace_back(op, cycle / iterationInterval);
189 transform::LoopPipelineOp::applyToOne(scf::ForOp target,
194 [
this](scf::ForOp forOp,
195 std::vector<std::pair<Operation *, unsigned>> &schedule)
mutable {
200 SimpleRewriter rewriter(getContext());
201 rewriter.setInsertionPoint(target);
203 pattern.returningMatchAndRewrite(target, rewriter);
205 results.push_back(*patternResult);
208 results.assign(1,
nullptr);
209 return emitDefaultSilenceableFailure(target);
217 transform::LoopUnrollOp::applyToOne(scf::ForOp target,
222 diag <<
"op failed to unroll";
233 class SCFTransformDialectExtension
235 SCFTransformDialectExtension> {
240 declareDependentDialect<pdl::PDLDialect>();
242 declareGeneratedDialect<AffineDialect>();
243 declareGeneratedDialect<func::FuncDialect>();
245 registerTransformOps<
247 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 253 #define GET_OP_CLASSES 254 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" Diagnostic & attachNote(Optional< Location > loc=llvm::None)
Attaches a note to the last diagnostic.
Include the generated interface declarations.
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
static std::string diag(llvm::Value &v)
The result of a transform IR operation application.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Operation is a basic unit of execution within MLIR.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of this operation but keeps the operation regions empty.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
This class represents an efficient way to signal success or failure.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class provides support for representing a failure result, or a valid value of type T...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
void addExtensions()
Add the given extensions to the registry.
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Location getLoc()
The source location the operation was defined or derived from.
GetScheduleFnType getScheduleFn
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
static llvm::ManagedStatic< PassManagerOptions > options
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
RAII guard to reset the insertion point of the builder when destroyed.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
MLIRContext is the top-level object for a collection of MLIR operations.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation `from`.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
Options to dictate how loops should be pipelined.
result_range getResults()
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
result_type_range getResultTypes()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
void registerTransformDialectExtension(DialectRegistry &registry)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)