MLIR  16.0.0git
SCFTransformOps.cpp
Go to the documentation of this file.
1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
20 
21 using namespace mlir;
22 
23 namespace {
24 /// A simple pattern rewriter that implements no special logic.
25 class SimpleRewriter : public PatternRewriter {
26 public:
27  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
28 };
29 } // namespace
30 
31 //===----------------------------------------------------------------------===//
32 // GetParentForOp
33 //===----------------------------------------------------------------------===//
34 
36 transform::GetParentForOp::apply(transform::TransformResults &results,
38  SetVector<Operation *> parents;
39  for (Operation *target : state.getPayloadOps(getTarget())) {
40  scf::ForOp loop;
41  Operation *current = target;
42  for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
43  loop = current->getParentOfType<scf::ForOp>();
44  if (!loop) {
45  DiagnosedSilenceableFailure diag = emitSilenceableError()
46  << "could not find an '"
47  << scf::ForOp::getOperationName()
48  << "' parent";
49  diag.attachNote(target->getLoc()) << "target op";
50  return diag;
51  }
52  current = loop;
53  }
54  parents.insert(loop);
55  }
56  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
58 }
59 
60 //===----------------------------------------------------------------------===//
61 // LoopOutlineOp
62 //===----------------------------------------------------------------------===//
63 
64 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
65 /// the provided rewriter for all operations to remain compatible with the
66 /// rewriting infra, as opposed to just splicing the op in place.
67 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
68  Operation *op) {
69  if (op->getNumRegions() != 1)
70  return nullptr;
72  b.setInsertionPoint(op);
73  scf::ExecuteRegionOp executeRegionOp =
74  b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
75  {
77  b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
78  Operation *clonedOp = b.cloneWithoutRegions(*op);
79  Region &clonedRegion = clonedOp->getRegions().front();
80  assert(clonedRegion.empty() && "expected empty region");
81  b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
82  clonedRegion.end());
83  b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
84  }
85  b.replaceOp(op, executeRegionOp.getResults());
86  return executeRegionOp;
87 }
88 
90 transform::LoopOutlineOp::apply(transform::TransformResults &results,
92  SmallVector<Operation *> transformed;
94  for (Operation *target : state.getPayloadOps(getTarget())) {
95  Location location = target->getLoc();
96  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
97  SimpleRewriter rewriter(getContext());
98  scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
99  if (!exec) {
100  DiagnosedSilenceableFailure diag = emitSilenceableError()
101  << "failed to outline";
102  diag.attachNote(target->getLoc()) << "target op";
103  return diag;
104  }
105  func::CallOp call;
107  rewriter, location, exec.getRegion(), getFuncName(), &call);
108 
109  if (failed(outlined)) {
110  (void)reportUnknownTransformError(target);
112  }
113 
114  if (symbolTableOp) {
115  SymbolTable &symbolTable =
116  symbolTables.try_emplace(symbolTableOp, symbolTableOp)
117  .first->getSecond();
118  symbolTable.insert(*outlined);
119  call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
120  }
121  transformed.push_back(*outlined);
122  }
123  results.set(getTransformed().cast<OpResult>(), transformed);
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // LoopPeelOp
129 //===----------------------------------------------------------------------===//
130 
132 transform::LoopPeelOp::applyToOne(scf::ForOp target,
133  SmallVector<Operation *> &results,
134  transform::TransformState &state) {
135  scf::ForOp result;
136  IRRewriter rewriter(target->getContext());
137  // This helper returns failure when peeling does not occur (i.e. when the IR
138  // is not modified). This is not a failure for the op as the postcondition:
139  // "the loop trip count is divisible by the step"
140  // is valid.
141  LogicalResult status =
142  scf::peelAndCanonicalizeForLoop(rewriter, target, result);
143  // TODO: Return both the peeled loop and the remainder loop.
144  results.push_back(failed(status) ? target : result);
146 }
147 
148 //===----------------------------------------------------------------------===//
149 // LoopPipelineOp
150 //===----------------------------------------------------------------------===//
151 
152 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
153 /// operation to its logical time position given the iteration interval and the
154 /// read latency. The latter is only relevant for vector transfers.
155 static void
156 loopScheduling(scf::ForOp forOp,
157  std::vector<std::pair<Operation *, unsigned>> &schedule,
158  unsigned iterationInterval, unsigned readLatency) {
159  auto getLatency = [&](Operation *op) -> unsigned {
160  if (isa<vector::TransferReadOp>(op))
161  return readLatency;
162  return 1;
163  };
164 
166  std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
167  for (Operation &op : forOp.getBody()->getOperations()) {
168  if (isa<scf::YieldOp>(op))
169  continue;
170  unsigned earlyCycle = 0;
171  for (Value operand : op.getOperands()) {
172  Operation *def = operand.getDefiningOp();
173  if (!def)
174  continue;
175  earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
176  }
177  opCycles[&op] = earlyCycle;
178  wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
179  }
180  for (const auto &it : wrappedSchedule) {
181  for (Operation *op : it.second) {
182  unsigned cycle = opCycles[op];
183  schedule.emplace_back(op, cycle / iterationInterval);
184  }
185  }
186 }
187 
189 transform::LoopPipelineOp::applyToOne(scf::ForOp target,
190  SmallVector<Operation *> &results,
191  transform::TransformState &state) {
193  options.getScheduleFn =
194  [this](scf::ForOp forOp,
195  std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
196  loopScheduling(forOp, schedule, getIterationInterval(),
197  getReadLatency());
198  };
199  scf::ForLoopPipeliningPattern pattern(options, target->getContext());
200  SimpleRewriter rewriter(getContext());
201  rewriter.setInsertionPoint(target);
202  FailureOr<scf::ForOp> patternResult =
203  pattern.returningMatchAndRewrite(target, rewriter);
204  if (succeeded(patternResult)) {
205  results.push_back(*patternResult);
207  }
208  results.assign(1, nullptr);
209  return emitDefaultSilenceableFailure(target);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // LoopUnrollOp
214 //===----------------------------------------------------------------------===//
215 
217 transform::LoopUnrollOp::applyToOne(scf::ForOp target,
218  SmallVector<Operation *> &results,
219  transform::TransformState &state) {
220  if (failed(loopUnrollByFactor(target, getFactor()))) {
221  Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note);
222  diag << "op failed to unroll";
223  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
224  }
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // Transform op registration
230 //===----------------------------------------------------------------------===//
231 
232 namespace {
233 class SCFTransformDialectExtension
235  SCFTransformDialectExtension> {
236 public:
237  using Base::Base;
238 
239  void init() {
240  declareDependentDialect<pdl::PDLDialect>();
241 
242  declareGeneratedDialect<AffineDialect>();
243  declareGeneratedDialect<func::FuncDialect>();
244 
245  registerTransformOps<
246 #define GET_OP_LIST
247 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
248  >();
249  }
250 };
251 } // namespace
252 
253 #define GET_OP_CLASSES
254 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
255 
257  registry.addExtensions<SCFTransformDialectExtension>();
258 }
Diagnostic & attachNote(Optional< Location > loc=llvm::None)
Attaches a note to the last diagnostic.
Include the generated interface declarations.
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition: Utils.cpp:139
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
void set(OpResult value, ArrayRef< Operation *> ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
static std::string diag(llvm::Value &v)
The result of a transform IR operation application.
static void loopScheduling(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned >> &schedule, unsigned iterationInterval, unsigned readLatency)
Callback for PipeliningOption.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:344
Operation * cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:522
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent"...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:169
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
bool empty()
Definition: Region.h:60
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
void addExtensions()
Add the given extensions to the registry.
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
Definition: LoopUtils.cpp:1091
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Definition: Patterns.h:34
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
The state maintained across applications of various ops implementing the TransformOpInterface.
GetScheduleFnType getScheduleFn
Definition: Transforms.h:129
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:584
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
static llvm::ManagedStatic< PassManagerOptions > options
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:377
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:294
iterator end()
Definition: Region.h:56
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition: SymbolTable.h:23
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, Operation *op)
Wraps the given operation op into an scf.execute_region operation.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
Options to dictate how loops should be pipelined.
Definition: Transforms.h:124
result_range getResults()
Definition: Operation.h:332
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
result_type_range getResultTypes()
Definition: Operation.h:345
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
void registerTransformDialectExtension(DialectRegistry &registry)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)