MLIR  16.0.0git
SCFTransformOps.cpp
Go to the documentation of this file.
1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // GetParentForOp
26 //===----------------------------------------------------------------------===//
28 transform::GetParentForOp::apply(transform::TransformResults &results,
30  SetVector<Operation *> parents;
31  for (Operation *target : state.getPayloadOps(getTarget())) {
32  Operation *loop, *current = target;
33  for (unsigned i = 0, e = getNumLoops(); i < e; ++i) {
34  loop = getAffine()
35  ? current->getParentOfType<AffineForOp>().getOperation()
36  : current->getParentOfType<scf::ForOp>().getOperation();
37  if (!loop) {
39  emitSilenceableError()
40  << "could not find an '"
41  << (getAffine() ? AffineForOp::getOperationName()
42  : scf::ForOp::getOperationName())
43  << "' parent";
44  diag.attachNote(target->getLoc()) << "target op";
45  results.set(getResult().cast<OpResult>(), {});
46  return diag;
47  }
48  current = loop;
49  }
50  parents.insert(loop);
51  }
52  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // LoopOutlineOp
58 //===----------------------------------------------------------------------===//
59 
60 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
61 /// the provided rewriter for all operations to remain compatible with the
62 /// rewriting infra, as opposed to just splicing the op in place.
63 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
64  Operation *op) {
65  if (op->getNumRegions() != 1)
66  return nullptr;
68  b.setInsertionPoint(op);
69  scf::ExecuteRegionOp executeRegionOp =
70  b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
71  {
73  b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
74  Operation *clonedOp = b.cloneWithoutRegions(*op);
75  Region &clonedRegion = clonedOp->getRegions().front();
76  assert(clonedRegion.empty() && "expected empty region");
77  b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
78  clonedRegion.end());
79  b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
80  }
81  b.replaceOp(op, executeRegionOp.getResults());
82  return executeRegionOp;
83 }
84 
86 transform::LoopOutlineOp::apply(transform::TransformResults &results,
88  SmallVector<Operation *> transformed;
90  for (Operation *target : state.getPayloadOps(getTarget())) {
91  Location location = target->getLoc();
92  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
93  TrivialPatternRewriter rewriter(getContext());
94  scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
95  if (!exec) {
96  DiagnosedSilenceableFailure diag = emitSilenceableError()
97  << "failed to outline";
98  diag.attachNote(target->getLoc()) << "target op";
99  results.set(getTransformed().cast<OpResult>(), {});
100  return diag;
101  }
102  func::CallOp call;
104  rewriter, location, exec.getRegion(), getFuncName(), &call);
105 
106  if (failed(outlined)) {
107  (void)reportUnknownTransformError(target);
109  }
110 
111  if (symbolTableOp) {
112  SymbolTable &symbolTable =
113  symbolTables.try_emplace(symbolTableOp, symbolTableOp)
114  .first->getSecond();
115  symbolTable.insert(*outlined);
116  call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
117  }
118  transformed.push_back(*outlined);
119  }
120  results.set(getTransformed().cast<OpResult>(), transformed);
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // LoopPeelOp
126 //===----------------------------------------------------------------------===//
127 
129 transform::LoopPeelOp::applyToOne(scf::ForOp target,
130  SmallVector<Operation *> &results,
131  transform::TransformState &state) {
132  scf::ForOp result;
133  IRRewriter rewriter(target->getContext());
134  // This helper returns failure when peeling does not occur (i.e. when the IR
135  // is not modified). This is not a failure for the op as the postcondition:
136  // "the loop trip count is divisible by the step"
137  // is valid.
138  LogicalResult status =
139  scf::peelAndCanonicalizeForLoop(rewriter, target, result);
140  // TODO: Return both the peeled loop and the remainder loop.
141  results.push_back(failed(status) ? target : result);
143 }
144 
145 //===----------------------------------------------------------------------===//
146 // LoopPipelineOp
147 //===----------------------------------------------------------------------===//
148 
149 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
150 /// operation to its logical time position given the iteration interval and the
151 /// read latency. The latter is only relevant for vector transfers.
152 static void
153 loopScheduling(scf::ForOp forOp,
154  std::vector<std::pair<Operation *, unsigned>> &schedule,
155  unsigned iterationInterval, unsigned readLatency) {
156  auto getLatency = [&](Operation *op) -> unsigned {
157  if (isa<vector::TransferReadOp>(op))
158  return readLatency;
159  return 1;
160  };
161 
163  std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
164  for (Operation &op : forOp.getBody()->getOperations()) {
165  if (isa<scf::YieldOp>(op))
166  continue;
167  unsigned earlyCycle = 0;
168  for (Value operand : op.getOperands()) {
169  Operation *def = operand.getDefiningOp();
170  if (!def)
171  continue;
172  earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
173  }
174  opCycles[&op] = earlyCycle;
175  wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
176  }
177  for (const auto &it : wrappedSchedule) {
178  for (Operation *op : it.second) {
179  unsigned cycle = opCycles[op];
180  schedule.emplace_back(op, cycle / iterationInterval);
181  }
182  }
183 }
184 
186 transform::LoopPipelineOp::applyToOne(scf::ForOp target,
187  SmallVector<Operation *> &results,
188  transform::TransformState &state) {
190  options.getScheduleFn =
191  [this](scf::ForOp forOp,
192  std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
193  loopScheduling(forOp, schedule, getIterationInterval(),
194  getReadLatency());
195  };
196  scf::ForLoopPipeliningPattern pattern(options, target->getContext());
197  TrivialPatternRewriter rewriter(getContext());
198  rewriter.setInsertionPoint(target);
199  FailureOr<scf::ForOp> patternResult =
200  pattern.returningMatchAndRewrite(target, rewriter);
201  if (succeeded(patternResult)) {
202  results.push_back(*patternResult);
204  }
205  results.assign(1, nullptr);
206  return emitDefaultSilenceableFailure(target);
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // LoopUnrollOp
211 //===----------------------------------------------------------------------===//
212 
214 transform::LoopUnrollOp::applyToOne(Operation *op,
215  SmallVector<Operation *> &results,
216  transform::TransformState &state) {
217  LogicalResult result(failure());
218  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
219  result = loopUnrollByFactor(scfFor, getFactor());
220  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
221  result = loopUnrollByFactor(affineFor, getFactor());
222 
223  if (failed(result)) {
225  diag << "Op failed to unroll";
227  }
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // Transform op registration
233 //===----------------------------------------------------------------------===//
234 
235 namespace {
236 class SCFTransformDialectExtension
238  SCFTransformDialectExtension> {
239 public:
240  using Base::Base;
241 
242  void init() {
243  declareGeneratedDialect<AffineDialect>();
244  declareGeneratedDialect<func::FuncDialect>();
245 
246  registerTransformOps<
247 #define GET_OP_LIST
248 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
249  >();
250  }
251 };
252 } // namespace
253 
254 #define GET_OP_CLASSES
255 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
256 
258  registry.addExtensions<SCFTransformDialectExtension>();
259 }
static std::string diag(llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void loopScheduling(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned >> &schedule, unsigned iterationInterval, unsigned readLatency)
Callback for PipeliningOption.
static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, Operation *op)
Wraps the given operation op into an scf.execute_region operation.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given diagnostic.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:594
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:64
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
Operation * cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:526
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:477
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:169
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
result_type_range getResultTypes()
Definition: Operation.h:345
result_range getResults()
Definition: Operation.h:332
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Definition: Patterns.h:34
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, ArrayRef< Operation * > ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list of payload IR ops.
The state maintained across applications of various ops implementing the TransformOpInterface.
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult peelAndCanonicalizeForLoop(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
Include the generated interface declarations.
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
Definition: LoopUtils.cpp:1091
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition: Utils.cpp:187
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options to dictate how loops should be pipelined.
Definition: Transforms.h:124