MLIR  18.0.0git
SCFTransformOps.cpp
Go to the documentation of this file.
1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
25 #include "mlir/IR/Dominance.h"
26 #include "mlir/IR/OpDefinition.h"
27 
28 using namespace mlir;
29 using namespace mlir::affine;
30 
31 //===----------------------------------------------------------------------===//
32 // Apply...PatternsOp
33 //===----------------------------------------------------------------------===//
34 
35 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
36  RewritePatternSet &patterns) {
38 }
39 
40 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
41  TypeConverter &typeConverter, RewritePatternSet &patterns) {
42  scf::populateSCFStructuralTypeConversions(typeConverter, patterns);
43 }
44 
45 void transform::ApplySCFStructuralConversionPatternsOp::
46  populateConversionTargetRules(const TypeConverter &typeConverter,
47  ConversionTarget &conversionTarget) {
49  conversionTarget);
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // ForallToForOp
54 //===----------------------------------------------------------------------===//
55 
57 transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
60  auto payload = state.getPayloadOps(getTarget());
61  if (!llvm::hasSingleElement(payload))
62  return emitSilenceableError() << "expected a single payload op";
63 
64  auto target = dyn_cast<scf::ForallOp>(*payload.begin());
65  if (!target) {
67  emitSilenceableError() << "expected the payload to be scf.forall";
68  diag.attachNote((*payload.begin())->getLoc()) << "payload op";
69  return diag;
70  }
71 
72  rewriter.setInsertionPoint(target);
73 
74  if (!target.getOutputs().empty()) {
75  return emitSilenceableError()
76  << "unsupported shared outputs (didn't bufferize?)";
77  }
78 
79  SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
80  SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
81  SmallVector<OpFoldResult> steps = target.getMixedStep();
82 
83  if (getNumResults() != lbs.size()) {
85  emitSilenceableError()
86  << "op expects as many results (" << getNumResults()
87  << ") as payload has induction variables (" << lbs.size() << ")";
88  diag.attachNote(target.getLoc()) << "payload op";
89  return diag;
90  }
91 
92  auto loc = target.getLoc();
94  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
95  Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
96  Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
97  Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
98  auto loop = rewriter.create<scf::ForOp>(
99  loc, lbValue, ubValue, stepValue, ValueRange(),
100  [](OpBuilder &, Location, Value, ValueRange) {});
101  ivs.push_back(loop.getInductionVar());
102  rewriter.setInsertionPointToStart(loop.getBody());
103  rewriter.create<scf::YieldOp>(loc);
104  rewriter.setInsertionPointToStart(loop.getBody());
105  }
106  rewriter.eraseOp(target.getBody()->getTerminator());
107  rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
108  ivs);
109  rewriter.eraseOp(target);
110 
111  for (auto &&[i, iv] : llvm::enumerate(ivs)) {
112  results.set(cast<OpResult>(getTransformed()[i]),
113  {iv.getParentBlock()->getParentOp()});
114  }
116 }
117 
118 //===----------------------------------------------------------------------===//
119 // LoopOutlineOp
120 //===----------------------------------------------------------------------===//
121 
122 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
123 /// the provided rewriter for all operations to remain compatible with the
124 /// rewriting infra, as opposed to just splicing the op in place.
125 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
126  Operation *op) {
127  if (op->getNumRegions() != 1)
128  return nullptr;
130  b.setInsertionPoint(op);
131  scf::ExecuteRegionOp executeRegionOp =
132  b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
133  {
135  b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
136  Operation *clonedOp = b.cloneWithoutRegions(*op);
137  Region &clonedRegion = clonedOp->getRegions().front();
138  assert(clonedRegion.empty() && "expected empty region");
139  b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
140  clonedRegion.end());
141  b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
142  }
143  b.replaceOp(op, executeRegionOp.getResults());
144  return executeRegionOp;
145 }
146 
148 transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter,
150  transform::TransformState &state) {
151  SmallVector<Operation *> functions;
154  for (Operation *target : state.getPayloadOps(getTarget())) {
155  Location location = target->getLoc();
156  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
157  scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
158  if (!exec) {
159  DiagnosedSilenceableFailure diag = emitSilenceableError()
160  << "failed to outline";
161  diag.attachNote(target->getLoc()) << "target op";
162  return diag;
163  }
164  func::CallOp call;
166  rewriter, location, exec.getRegion(), getFuncName(), &call);
167 
168  if (failed(outlined))
169  return emitDefaultDefiniteFailure(target);
170 
171  if (symbolTableOp) {
172  SymbolTable &symbolTable =
173  symbolTables.try_emplace(symbolTableOp, symbolTableOp)
174  .first->getSecond();
175  symbolTable.insert(*outlined);
176  call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
177  }
178  functions.push_back(*outlined);
179  calls.push_back(call);
180  }
181  results.set(cast<OpResult>(getFunction()), functions);
182  results.set(cast<OpResult>(getCall()), calls);
184 }
185 
186 //===----------------------------------------------------------------------===//
187 // LoopPeelOp
188 //===----------------------------------------------------------------------===//
189 
191 transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
192  scf::ForOp target,
194  transform::TransformState &state) {
195  scf::ForOp result;
196  LogicalResult status =
197  scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
198  if (failed(status)) {
199  DiagnosedSilenceableFailure diag = emitSilenceableError()
200  << "failed to peel";
201  return diag;
202  }
203  results.push_back(target);
204  results.push_back(result);
205 
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // LoopPipelineOp
211 //===----------------------------------------------------------------------===//
212 
213 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
214 /// operation to its logical time position given the iteration interval and the
215 /// read latency. The latter is only relevant for vector transfers.
216 static void
217 loopScheduling(scf::ForOp forOp,
218  std::vector<std::pair<Operation *, unsigned>> &schedule,
219  unsigned iterationInterval, unsigned readLatency) {
220  auto getLatency = [&](Operation *op) -> unsigned {
221  if (isa<vector::TransferReadOp>(op))
222  return readLatency;
223  return 1;
224  };
225 
227  std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
228  for (Operation &op : forOp.getBody()->getOperations()) {
229  if (isa<scf::YieldOp>(op))
230  continue;
231  unsigned earlyCycle = 0;
232  for (Value operand : op.getOperands()) {
233  Operation *def = operand.getDefiningOp();
234  if (!def)
235  continue;
236  earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
237  }
238  opCycles[&op] = earlyCycle;
239  wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
240  }
241  for (const auto &it : wrappedSchedule) {
242  for (Operation *op : it.second) {
243  unsigned cycle = opCycles[op];
244  schedule.emplace_back(op, cycle / iterationInterval);
245  }
246  }
247 }
248 
250 transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter,
251  scf::ForOp target,
253  transform::TransformState &state) {
255  options.getScheduleFn =
256  [this](scf::ForOp forOp,
257  std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
258  loopScheduling(forOp, schedule, getIterationInterval(),
259  getReadLatency());
260  };
261  scf::ForLoopPipeliningPattern pattern(options, target->getContext());
262  rewriter.setInsertionPoint(target);
263  FailureOr<scf::ForOp> patternResult =
264  scf::pipelineForLoop(rewriter, target, options);
265  if (succeeded(patternResult)) {
266  results.push_back(*patternResult);
268  }
269  return emitDefaultSilenceableFailure(target);
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // LoopPromoteIfOneIterationOp
274 //===----------------------------------------------------------------------===//
275 
276 DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne(
277  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
279  transform::TransformState &state) {
280  (void)target.promoteIfSingleIteration(rewriter);
282 }
283 
284 void transform::LoopPromoteIfOneIterationOp::getEffects(
286  consumesHandle(getTarget(), effects);
287  modifiesPayload(effects);
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // LoopUnrollOp
292 //===----------------------------------------------------------------------===//
293 
295 transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
296  Operation *op,
298  transform::TransformState &state) {
299  LogicalResult result(failure());
300  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
301  result = loopUnrollByFactor(scfFor, getFactor());
302  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
303  result = loopUnrollByFactor(affineFor, getFactor());
304 
305  if (failed(result)) {
306  DiagnosedSilenceableFailure diag = emitSilenceableError()
307  << "failed to unroll";
308  return diag;
309  }
311 }
312 
313 //===----------------------------------------------------------------------===//
314 // LoopCoalesceOp
315 //===----------------------------------------------------------------------===//
316 
318 transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
319  Operation *op,
321  transform::TransformState &state) {
322  LogicalResult result(failure());
323  if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
324  result = coalescePerfectlyNestedLoops(scfForOp);
325  else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
326  result = coalescePerfectlyNestedLoops(affineForOp);
327 
328  results.push_back(op);
329  if (failed(result)) {
330  DiagnosedSilenceableFailure diag = emitSilenceableError()
331  << "failed to coalesce";
332  return diag;
333  }
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // TakeAssumedBranchOp
339 //===----------------------------------------------------------------------===//
340 /// Replaces the given op with the contents of the given single-block region,
341 /// using the operands of the block terminator to replace operation results.
342 static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
343  Region &region) {
344  assert(llvm::hasSingleElement(region) && "expected single-region block");
345  Block *block = &region.front();
346  Operation *terminator = block->getTerminator();
347  ValueRange results = terminator->getOperands();
348  rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{});
349  rewriter.replaceOp(op, results);
350  rewriter.eraseOp(terminator);
351 }
352 
353 DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
354  transform::TransformRewriter &rewriter, scf::IfOp ifOp,
356  transform::TransformState &state) {
357  rewriter.setInsertionPoint(ifOp);
358  Region &region =
359  getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
360  if (!llvm::hasSingleElement(region)) {
361  return emitDefiniteFailure()
362  << "requires an scf.if op with a single-block "
363  << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
364  }
365  replaceOpWithRegion(rewriter, ifOp, region);
367 }
368 
369 void transform::TakeAssumedBranchOp::getEffects(
371  onlyReadsHandle(getTarget(), effects);
372  modifiesPayload(effects);
373 }
374 
375 //===----------------------------------------------------------------------===//
376 // LoopFuseSibling
377 //===----------------------------------------------------------------------===//
378 
379 /// Check if `target` and `source` are siblings, in the context that `target`
380 /// is being fused into `source`.
381 ///
382 /// This is a simple check that just checks if both operations are in the same
383 /// block and some checks to ensure that the fused IR does not violate
384 /// dominance.
386  Operation *source) {
387  // Check if both operations are same.
388  if (target == source)
389  return emitSilenceableFailure(source)
390  << "target and source need to be different loops";
391 
392  // Check if both operations are in the same block.
393  if (target->getBlock() != source->getBlock())
394  return emitSilenceableFailure(source)
395  << "target and source are not in the same block";
396 
397  // Check if fusion will violate dominance.
398  DominanceInfo domInfo(source);
399  if (target->isBeforeInBlock(source)) {
400  // Since, `target` is before `source`, all users of results of `target`
401  // need to be dominated by `source`.
402  for (Operation *user : target->getUsers()) {
403  if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
404  return emitSilenceableFailure(target)
405  << "user of results of target should be properly dominated by "
406  "source";
407  }
408  }
409  } else {
410  // Since `target` is after `source`, all values used by `target` need
411  // to dominate `source`.
412 
413  // Check if operands of `target` are dominated by `source`.
414  for (Value operand : target->getOperands()) {
415  Operation *operandOp = operand.getDefiningOp();
416  // If operand does not have a defining operation, it is a block arguement,
417  // which will always dominate `source`, since `target` and `source` are in
418  // the same block and the operand dominated `source` before.
419  if (!operandOp)
420  continue;
421 
422  // Operand's defining operation should properly dominate `source`.
423  if (!domInfo.properlyDominates(operandOp, source,
424  /*enclosingOpOk=*/false))
425  return emitSilenceableFailure(target)
426  << "operands of target should be properly dominated by source";
427  }
428 
429  // Check if values used by `target` are dominated by `source`.
430  bool failed = false;
431  OpOperand *failedValue = nullptr;
432  visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
433  if (!domInfo.properlyDominates(operand->getOwner(), source,
434  /*enclosingOpOk=*/false)) {
435  failed = true;
436  failedValue = operand;
437  }
438  });
439 
440  if (failed)
441  return emitSilenceableFailure(failedValue->getOwner())
442  << "values used inside regions of target should be properly "
443  "dominated by source";
444  }
445 
447 }
448 
449 /// Check if `target` can be fused into `source`.
450 ///
451 /// This is a simple check that just checks if both loops have same
452 /// bounds, steps and mapping. This check does not ensure that the side effects
453 /// of `target` are independent of `source` or vice-versa. It is the
454 /// responsibility of the caller to ensure that.
456  Operation *source) {
457  auto targetOp = dyn_cast<scf::ForallOp>(target);
458  auto sourceOp = dyn_cast<scf::ForallOp>(source);
459  if (!targetOp || !sourceOp)
460  return false;
461 
462  return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
463  targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
464  targetOp.getMixedStep() == sourceOp.getMixedStep() &&
465  targetOp.getMapping() == sourceOp.getMapping();
466 }
467 
468 /// Fuse `target` into `source` assuming they are siblings and indepndent.
469 /// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
470 static Operation *fuseSiblings(Operation *target, Operation *source,
471  RewriterBase &rewriter) {
472  auto targetOp = dyn_cast<scf::ForallOp>(target);
473  auto sourceOp = dyn_cast<scf::ForallOp>(source);
474  if (!targetOp || !sourceOp)
475  return nullptr;
476  return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
477 }
478 
480 transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
482  transform::TransformState &state) {
483  auto targetOps = state.getPayloadOps(getTarget());
484  auto sourceOps = state.getPayloadOps(getSource());
485 
486  if (!llvm::hasSingleElement(targetOps) ||
487  !llvm::hasSingleElement(sourceOps)) {
488  return emitDefiniteFailure()
489  << "requires exactly one target handle (got "
490  << llvm::range_size(targetOps) << ") and exactly one "
491  << "source handle (got " << llvm::range_size(sourceOps) << ")";
492  }
493 
494  Operation *target = *targetOps.begin();
495  Operation *source = *sourceOps.begin();
496 
497  // Check if the target and source are siblings.
498  DiagnosedSilenceableFailure diag = isOpSibling(target, source);
499  if (!diag.succeeded())
500  return diag;
501 
502  // Check if the target can be fused into source.
503  if (!isForallWithIdenticalConfiguration(target, source)) {
504  return emitSilenceableFailure(target->getLoc())
505  << "operations cannot be fused";
506  }
507 
508  Operation *fusedLoop = fuseSiblings(target, source, rewriter);
509  assert(fusedLoop && "failed to fuse operations");
510 
511  results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // Transform op registration
517 //===----------------------------------------------------------------------===//
518 
519 namespace {
520 class SCFTransformDialectExtension
522  SCFTransformDialectExtension> {
523 public:
524  using Base::Base;
525 
526  void init() {
527  declareGeneratedDialect<affine::AffineDialect>();
528  declareGeneratedDialect<func::FuncDialect>();
529 
530  registerTransformOps<
531 #define GET_OP_LIST
532 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
533  >();
534  }
535 };
536 } // namespace
537 
538 #define GET_OP_CLASSES
539 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
540 
542  registry.addExtensions<SCFTransformDialectExtension>();
543 }
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static DiagnosedSilenceableFailure isOpSibling(Operation *target, Operation *source)
Check if target and source are siblings, in the context that target is being fused into source.
static void loopScheduling(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned >> &schedule, unsigned iterationInterval, unsigned readLatency)
Callback for PipeliningOption.
static bool isForallWithIdenticalConfiguration(Operation *target, Operation *source)
Check if target can be fused into source.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, Operation *op)
Wraps the given operation op into an scf.execute_region operation.
static Operation * fuseSiblings(Operation *target, Operation *source, RewriterBase &rewriter)
Fuse target into source assuming they are siblings and independent.
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:238
This class describes a specific conversion target.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
Definition: Dominance.h:121
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:134
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:561
This class represents an operand of an operation.
Definition: Value.h:263
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:652
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:852
result_range getResults()
Definition: Operation.h:410
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
Block & front()
Definition: Region.h:65
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Type conversion class.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
Definition: LoopUtils.cpp:1016
LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op)
Walk either an scf.for or an affine.for to find a band to coalesce.
Definition: LoopUtils.h:304
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void registerTransformDialectExtension(DialectRegistry &registry)
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void populateSCFStructuralTypeConversions(TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:40
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition: Utils.cpp:120
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition: Utils.cpp:906
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Definition: RegionUtils.cpp:36
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Options to dictate how loops should be pipelined.
Definition: Transforms.h:104