MLIR  22.0.0git
SCFTransformOps.cpp
Go to the documentation of this file.
1 //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/OpDefinition.h"
26 
27 using namespace mlir;
28 using namespace mlir::affine;
29 
30 //===----------------------------------------------------------------------===//
31 // Apply...PatternsOp
32 //===----------------------------------------------------------------------===//
33 
34 void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns(
37 }
38 
39 void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns(
40  TypeConverter &typeConverter, RewritePatternSet &patterns) {
42 }
43 
44 void transform::ApplySCFStructuralConversionPatternsOp::
45  populateConversionTargetRules(const TypeConverter &typeConverter,
46  ConversionTarget &conversionTarget) {
48  conversionTarget);
49 }
50 
51 void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
52  TypeConverter &typeConverter, RewritePatternSet &patterns) {
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // ForallToForOp
58 //===----------------------------------------------------------------------===//
59 
61 transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
64  auto payload = state.getPayloadOps(getTarget());
65  if (!llvm::hasSingleElement(payload))
66  return emitSilenceableError() << "expected a single payload op";
67 
68  auto target = dyn_cast<scf::ForallOp>(*payload.begin());
69  if (!target) {
71  emitSilenceableError() << "expected the payload to be scf.forall";
72  diag.attachNote((*payload.begin())->getLoc()) << "payload op";
73  return diag;
74  }
75 
76  if (!target.getOutputs().empty()) {
77  return emitSilenceableError()
78  << "unsupported shared outputs (didn't bufferize?)";
79  }
80 
81  SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
82 
83  if (getNumResults() != lbs.size()) {
85  emitSilenceableError()
86  << "op expects as many results (" << getNumResults()
87  << ") as payload has induction variables (" << lbs.size() << ")";
88  diag.attachNote(target.getLoc()) << "payload op";
89  return diag;
90  }
91 
92  SmallVector<Operation *> opResults;
93  if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
94  DiagnosedSilenceableFailure diag = emitSilenceableError()
95  << "failed to convert forall into for";
96  return diag;
97  }
98 
99  for (auto &&[i, res] : llvm::enumerate(opResults)) {
100  results.set(cast<OpResult>(getTransformed()[i]), {res});
101  }
103 }
104 
105 //===----------------------------------------------------------------------===//
106 // ForallToParallelOp
107 //===----------------------------------------------------------------------===//
108 
110 transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
112  transform::TransformState &state) {
113  auto payload = state.getPayloadOps(getTarget());
114  if (!llvm::hasSingleElement(payload))
115  return emitSilenceableError() << "expected a single payload op";
116 
117  auto target = dyn_cast<scf::ForallOp>(*payload.begin());
118  if (!target) {
120  emitSilenceableError() << "expected the payload to be scf.forall";
121  diag.attachNote((*payload.begin())->getLoc()) << "payload op";
122  return diag;
123  }
124 
125  if (!target.getOutputs().empty()) {
126  return emitSilenceableError()
127  << "unsupported shared outputs (didn't bufferize?)";
128  }
129 
130  if (getNumResults() != 1) {
131  DiagnosedSilenceableFailure diag = emitSilenceableError()
132  << "op expects one result, given "
133  << getNumResults();
134  diag.attachNote(target.getLoc()) << "payload op";
135  return diag;
136  }
137 
138  scf::ParallelOp opResult;
139  if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) {
141  emitSilenceableError() << "failed to convert forall into parallel";
142  return diag;
143  }
144 
145  results.set(cast<OpResult>(getTransformed()[0]), {opResult});
147 }
148 
149 //===----------------------------------------------------------------------===//
150 // ParallelForToNestedForOps
151 //===----------------------------------------------------------------------===//
152 
153 DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
156  auto payload = state.getPayloadOps(getTarget());
157  if (!llvm::hasSingleElement(payload))
158  return emitSilenceableError() << "expected a single payload op";
159 
160  auto target = dyn_cast<scf::ParallelOp>(*payload.begin());
161  if (!target) {
163  emitSilenceableError() << "expected the payload to be scf.parallel";
164  diag.attachNote((*payload.begin())->getLoc()) << "payload op";
165  return diag;
166  }
167 
168  if (getNumResults() != 1) {
169  DiagnosedSilenceableFailure diag = emitSilenceableError()
170  << "op expects one result, given "
171  << getNumResults();
172  diag.attachNote(target.getLoc()) << "payload op";
173  return diag;
174  }
175 
176  FailureOr<scf::LoopNest> loopNest =
177  scf::parallelForToNestedFors(rewriter, target);
178  if (failed(loopNest)) {
180  emitSilenceableError() << "failed to convert parallel into nested fors";
181  return diag;
182  }
183 
184  results.set(cast<OpResult>(getTransformed()[0]), {loopNest->loops.front()});
186 }
187 
188 //===----------------------------------------------------------------------===//
189 // LoopOutlineOp
190 //===----------------------------------------------------------------------===//
191 
192 /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses
193 /// the provided rewriter for all operations to remain compatible with the
194 /// rewriting infra, as opposed to just splicing the op in place.
195 static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
196  Operation *op) {
197  if (op->getNumRegions() != 1)
198  return nullptr;
200  b.setInsertionPoint(op);
201  scf::ExecuteRegionOp executeRegionOp =
202  scf::ExecuteRegionOp::create(b, op->getLoc(), op->getResultTypes());
203  {
205  b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
206  Operation *clonedOp = b.cloneWithoutRegions(*op);
207  Region &clonedRegion = clonedOp->getRegions().front();
208  assert(clonedRegion.empty() && "expected empty region");
209  b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
210  clonedRegion.end());
211  scf::YieldOp::create(b, op->getLoc(), clonedOp->getResults());
212  }
213  b.replaceOp(op, executeRegionOp.getResults());
214  return executeRegionOp;
215 }
216 
218 transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter,
220  transform::TransformState &state) {
221  SmallVector<Operation *> functions;
224  for (Operation *target : state.getPayloadOps(getTarget())) {
225  Location location = target->getLoc();
226  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
227  scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
228  if (!exec) {
229  DiagnosedSilenceableFailure diag = emitSilenceableError()
230  << "failed to outline";
231  diag.attachNote(target->getLoc()) << "target op";
232  return diag;
233  }
234  func::CallOp call;
235  FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
236  rewriter, location, exec.getRegion(), getFuncName(), &call);
237 
238  if (failed(outlined))
239  return emitDefaultDefiniteFailure(target);
240 
241  if (symbolTableOp) {
242  SymbolTable &symbolTable =
243  symbolTables.try_emplace(symbolTableOp, symbolTableOp)
244  .first->getSecond();
245  symbolTable.insert(*outlined);
246  call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
247  }
248  functions.push_back(*outlined);
249  calls.push_back(call);
250  }
251  results.set(cast<OpResult>(getFunction()), functions);
252  results.set(cast<OpResult>(getCall()), calls);
254 }
255 
256 //===----------------------------------------------------------------------===//
257 // LoopPeelOp
258 //===----------------------------------------------------------------------===//
259 
261 transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter,
262  scf::ForOp target,
264  transform::TransformState &state) {
265  scf::ForOp result;
266  if (getPeelFront()) {
267  LogicalResult status =
268  scf::peelForLoopFirstIteration(rewriter, target, result);
269  if (failed(status)) {
271  emitSilenceableError() << "failed to peel the first iteration";
272  return diag;
273  }
274  } else {
275  LogicalResult status =
276  scf::peelForLoopAndSimplifyBounds(rewriter, target, result);
277  if (failed(status)) {
278  DiagnosedSilenceableFailure diag = emitSilenceableError()
279  << "failed to peel the last iteration";
280  return diag;
281  }
282  }
283 
284  results.push_back(target);
285  results.push_back(result);
286 
288 }
289 
290 //===----------------------------------------------------------------------===//
291 // LoopPipelineOp
292 //===----------------------------------------------------------------------===//
293 
294 /// Callback for PipeliningOption. Populates `schedule` with the mapping from an
295 /// operation to its logical time position given the iteration interval and the
296 /// read latency. The latter is only relevant for vector transfers.
297 static void
298 loopScheduling(scf::ForOp forOp,
299  std::vector<std::pair<Operation *, unsigned>> &schedule,
300  unsigned iterationInterval, unsigned readLatency) {
301  auto getLatency = [&](Operation *op) -> unsigned {
302  if (isa<vector::TransferReadOp>(op))
303  return readLatency;
304  return 1;
305  };
306 
307  std::optional<int64_t> ubConstant =
308  getConstantIntValue(forOp.getUpperBound());
309  std::optional<int64_t> lbConstant =
310  getConstantIntValue(forOp.getLowerBound());
312  std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
313  for (Operation &op : forOp.getBody()->getOperations()) {
314  if (isa<scf::YieldOp>(op))
315  continue;
316  unsigned earlyCycle = 0;
317  for (Value operand : op.getOperands()) {
318  Operation *def = operand.getDefiningOp();
319  if (!def)
320  continue;
321  if (ubConstant && lbConstant) {
322  unsigned ubInt = ubConstant.value();
323  unsigned lbInt = lbConstant.value();
324  auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def));
325  earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency);
326  } else {
327  earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def));
328  }
329  }
330  opCycles[&op] = earlyCycle;
331  wrappedSchedule[earlyCycle % iterationInterval].push_back(&op);
332  }
333  for (const auto &it : wrappedSchedule) {
334  for (Operation *op : it.second) {
335  unsigned cycle = opCycles[op];
336  schedule.emplace_back(op, cycle / iterationInterval);
337  }
338  }
339 }
340 
342 transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter,
343  scf::ForOp target,
345  transform::TransformState &state) {
347  options.getScheduleFn =
348  [this](scf::ForOp forOp,
349  std::vector<std::pair<Operation *, unsigned>> &schedule) mutable {
350  loopScheduling(forOp, schedule, getIterationInterval(),
351  getReadLatency());
352  };
353  scf::ForLoopPipeliningPattern pattern(options, target->getContext());
354  rewriter.setInsertionPoint(target);
355  FailureOr<scf::ForOp> patternResult =
356  scf::pipelineForLoop(rewriter, target, options);
357  if (succeeded(patternResult)) {
358  results.push_back(*patternResult);
360  }
361  return emitDefaultSilenceableFailure(target);
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // LoopPromoteIfOneIterationOp
366 //===----------------------------------------------------------------------===//
367 
368 DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne(
369  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
371  transform::TransformState &state) {
372  (void)target.promoteIfSingleIteration(rewriter);
374 }
375 
376 void transform::LoopPromoteIfOneIterationOp::getEffects(
378  consumesHandle(getTargetMutable(), effects);
379  modifiesPayload(effects);
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // LoopUnrollOp
384 //===----------------------------------------------------------------------===//
385 
387 transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
388  Operation *op,
390  transform::TransformState &state) {
391  LogicalResult result(failure());
392  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
393  result = loopUnrollByFactor(scfFor, getFactor());
394  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
395  result = loopUnrollByFactor(affineFor, getFactor());
396  else
397  return emitSilenceableError()
398  << "failed to unroll, incorrect type of payload";
399 
400  if (failed(result))
401  return emitSilenceableError() << "failed to unroll";
402 
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // LoopUnrollAndJamOp
408 //===----------------------------------------------------------------------===//
409 
410 DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
413  transform::TransformState &state) {
414  LogicalResult result(failure());
415  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
416  result = loopUnrollJamByFactor(scfFor, getFactor());
417  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
418  result = loopUnrollJamByFactor(affineFor, getFactor());
419  else
420  return emitSilenceableError()
421  << "failed to unroll and jam, incorrect type of payload";
422 
423  if (failed(result))
424  return emitSilenceableError() << "failed to unroll and jam";
425 
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // LoopCoalesceOp
431 //===----------------------------------------------------------------------===//
432 
434 transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
435  Operation *op,
437  transform::TransformState &state) {
438  LogicalResult result(failure());
439  if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
440  result = coalescePerfectlyNestedSCFForLoops(scfForOp);
441  else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
442  result = coalescePerfectlyNestedAffineLoops(affineForOp);
443 
444  results.push_back(op);
445  if (failed(result)) {
446  DiagnosedSilenceableFailure diag = emitSilenceableError()
447  << "failed to coalesce";
448  return diag;
449  }
451 }
452 
453 //===----------------------------------------------------------------------===//
454 // TakeAssumedBranchOp
455 //===----------------------------------------------------------------------===//
456 /// Replaces the given op with the contents of the given single-block region,
457 /// using the operands of the block terminator to replace operation results.
///
/// The operations of the region's only block are inlined immediately before
/// `op` through `rewriter`; `op` is then replaced by the values that the
/// block terminator forwarded, and the terminator itself is erased. No values
/// are supplied for block arguments, so the block is presumably expected to
/// have none (NOTE(review): confirm against callers).
///
/// \param rewriter Rewriter through which all IR mutations are performed.
/// \param op       Operation to replace.
/// \param region   Region whose single block provides the replacement ops;
///                 asserted to contain exactly one block.
458 static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op,
459  Region &region) {
460  assert(region.hasOneBlock() && "expected single-block region");
461  Block *block = &region.front();
462  Operation *terminator = block->getTerminator();
 // Capture the terminator operands before the block is inlined; they become
 // the replacement values for the results of `op`.
463  ValueRange results = terminator->getOperands();
464  rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{});
465  rewriter.replaceOp(op, results);
 // The terminator is erased only after `replaceOp` has consumed `results`.
466  rewriter.eraseOp(terminator);
467 }
468 
469 DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne(
470  transform::TransformRewriter &rewriter, scf::IfOp ifOp,
472  transform::TransformState &state) {
473  rewriter.setInsertionPoint(ifOp);
474  Region &region =
475  getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion();
476  if (!region.hasOneBlock()) {
477  return emitDefiniteFailure()
478  << "requires an scf.if op with a single-block "
479  << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region";
480  }
481  replaceOpWithRegion(rewriter, ifOp, region);
483 }
484 
485 void transform::TakeAssumedBranchOp::getEffects(
487  onlyReadsHandle(getTargetMutable(), effects);
488  modifiesPayload(effects);
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // LoopFuseSiblingOp
493 //===----------------------------------------------------------------------===//
494 
495 /// Check if `target` and `source` are siblings, in the context that `target`
496 /// is being fused into `source`.
497 ///
498 /// This is a simple check that verifies that both operations are in the same
499 /// block and that fusing them would not produce IR that violates
500 /// dominance.
502  Operation *source) {
503  // Check if both operations are same.
504  if (target == source)
505  return emitSilenceableFailure(source)
506  << "target and source need to be different loops";
507 
508  // Check if both operations are in the same block.
509  if (target->getBlock() != source->getBlock())
510  return emitSilenceableFailure(source)
511  << "target and source are not in the same block";
512 
513  // Check if fusion will violate dominance.
514  DominanceInfo domInfo(source);
515  if (target->isBeforeInBlock(source)) {
516  // Since `target` is before `source`, all users of results of `target`
517  // need to be dominated by `source`.
518  for (Operation *user : target->getUsers()) {
519  if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
520  return emitSilenceableFailure(target)
521  << "user of results of target should be properly dominated by "
522  "source";
523  }
524  }
525  } else {
526  // Since `target` is after `source`, all values used by `target` need
527  // to dominate `source`.
528 
529  // Check if the defining ops of the operands of `target` dominate `source`.
530  for (Value operand : target->getOperands()) {
531  Operation *operandOp = operand.getDefiningOp();
532  // Operands without defining operations are block arguments. When `target`
533  // and `source` occur in the same block, these operands dominate `source`.
534  if (!operandOp)
535  continue;
536 
537  // Operand's defining operation should properly dominate `source`.
538  if (!domInfo.properlyDominates(operandOp, source,
539  /*enclosingOpOk=*/false))
540  return emitSilenceableFailure(target)
541  << "operands of target should be properly dominated by source";
542  }
543 
544  // Check if values used inside the regions of `target` dominate `source`.
545  bool failed = false;
546  OpOperand *failedValue = nullptr;
547  visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
548  Operation *operandOp = operand->get().getDefiningOp();
549  if (operandOp && !domInfo.properlyDominates(operandOp, source,
550  /*enclosingOpOk=*/false)) {
551  // `operand` is not an argument of an enclosing block and the defining
552  // op of `operand` is outside `target` but does not dominate `source`.
553  failed = true;
554  failedValue = operand;
555  }
556  });
557 
558  if (failed)
559  return emitSilenceableFailure(failedValue->getOwner())
560  << "values used inside regions of target should be properly "
561  "dominated by source";
562  }
563 
565 }
566 
567 /// Check if `target` scf.forall can be fused into `source` scf.forall.
568 ///
569 /// This simply checks if both loops have the same bounds, steps and mapping.
570 /// No attempt is made at checking that the side effects of `target` and
571 /// `source` are independent of each other.
573  Operation *source) {
574  auto targetOp = dyn_cast<scf::ForallOp>(target);
575  auto sourceOp = dyn_cast<scf::ForallOp>(source);
576  if (!targetOp || !sourceOp)
577  return false;
578 
579  return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
580  targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
581  targetOp.getMixedStep() == sourceOp.getMixedStep() &&
582  targetOp.getMapping() == sourceOp.getMapping();
583 }
584 
585 /// Check if `target` scf.for can be fused into `source` scf.for.
586 ///
587 /// This simply checks if both loops have the same bounds and steps. No attempt
588 /// is made at checking that the side effects of `target` and `source` are
589 /// independent of each other.
591  Operation *source) {
592  auto targetOp = dyn_cast<scf::ForOp>(target);
593  auto sourceOp = dyn_cast<scf::ForOp>(source);
594  if (!targetOp || !sourceOp)
595  return false;
596 
597  return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
598  targetOp.getUpperBound() == sourceOp.getUpperBound() &&
599  targetOp.getStep() == sourceOp.getStep();
600 }
601 
603 transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
605  transform::TransformState &state) {
606  auto targetOps = state.getPayloadOps(getTarget());
607  auto sourceOps = state.getPayloadOps(getSource());
608 
609  if (!llvm::hasSingleElement(targetOps) ||
610  !llvm::hasSingleElement(sourceOps)) {
611  return emitDefiniteFailure()
612  << "requires exactly one target handle (got "
613  << llvm::range_size(targetOps) << ") and exactly one "
614  << "source handle (got " << llvm::range_size(sourceOps) << ")";
615  }
616 
617  Operation *target = *targetOps.begin();
618  Operation *source = *sourceOps.begin();
619 
620  // Check if the target and source are siblings.
621  DiagnosedSilenceableFailure diag = isOpSibling(target, source);
622  if (!diag.succeeded())
623  return diag;
624 
625  Operation *fusedLoop;
626  /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
627  if (isForWithIdenticalConfiguration(target, source)) {
628  fusedLoop = fuseIndependentSiblingForLoops(
629  cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
630  } else if (isForallWithIdenticalConfiguration(target, source)) {
632  cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
633  } else {
634  return emitSilenceableFailure(target->getLoc())
635  << "operations cannot be fused";
636  }
637 
638  assert(fusedLoop && "failed to fuse operations");
639 
640  results.set(cast<OpResult>(getFusedLoop()), {fusedLoop});
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // Transform op registration
646 //===----------------------------------------------------------------------===//
647 
648 namespace {
649 class SCFTransformDialectExtension
651  SCFTransformDialectExtension> {
652 public:
653  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension)
654 
655  using Base::Base;
656 
657  void init() {
658  declareGeneratedDialect<affine::AffineDialect>();
659  declareGeneratedDialect<func::FuncDialect>();
660 
661  registerTransformOps<
662 #define GET_OP_LIST
663 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
664  >();
665  }
666 };
667 } // namespace
668 
669 #define GET_OP_CLASSES
670 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc"
671 
673  registry.addExtensions<SCFTransformDialectExtension>();
674 }
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool isForWithIdenticalConfiguration(Operation *target, Operation *source)
Check if target scf.for can be fused into source scf.for.
static DiagnosedSilenceableFailure isOpSibling(Operation *target, Operation *source)
Check if target and source are siblings, in the context that target is being fused into source.
static void loopScheduling(scf::ForOp forOp, std::vector< std::pair< Operation *, unsigned >> &schedule, unsigned iterationInterval, unsigned readLatency)
Callback for PipeliningOption.
static bool isForallWithIdenticalConfiguration(Operation *target, Operation *source)
Check if target scf.forall can be fused into source scf.forall.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, Operation *op)
Wraps the given operation op into an scf.execute_region operation.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
This class describes a specific conversion target.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:323
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Operation * cloneWithoutRegions(Operation &op, IRMapping &mapper)
Creates a deep copy of this operation but keep the operation regions empty.
Definition: Builders.h:587
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getResults()
Definition: Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
iterator end()
Definition: Region.h:56
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
StringAttr insert(Operation *symbol, Block::iterator insertPt={})
Insert a new symbol into the table, and rename it as necessary to avoid collisions.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Type conversion class.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
void push_back(Operation *op)
Appends an element to the list.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr, bool cleanUpUnroll=false)
Unrolls this for operation by the specified unroll factor.
Definition: LoopUtils.cpp:995
LogicalResult loopUnrollJamByFactor(AffineForOp forOp, uint64_t unrollJamFactor)
Unrolls and jams this loop by the specified factor.
Definition: LoopUtils.cpp:1084
LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op)
Walk an affine.for to find a band to coalesce.
Definition: LoopUtils.cpp:2769
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void registerTransformDialectExtension(DialectRegistry &registry)
FailureOr< scf::LoopNest > parallelForToNestedFors(RewriterBase &rewriter, ParallelOp parallelOp)
Try converting scf.parallel into a nest of scf.for loops.
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp, SmallVectorImpl< Operation * > *results=nullptr)
Try converting scf.forall into a set of nested scf.for loops.
LogicalResult peelForLoopFirstIteration(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Peel the first iteration out of the scf.for loop.
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp, ParallelOp *result=nullptr)
Try converting scf.forall into an scf.parallel loop.
void populateSCFStructuralTypeConversions(const TypeConverter &typeConverter, RewritePatternSet &patterns)
Similar to populateSCFStructuralTypeConversionsAndLegality but does not populate the conversion targe...
void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns)
Populate patterns for canonicalizing operations inside SCF loop bodies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
void populateSCFStructuralTypeConversionTarget(const TypeConverter &typeConverter, ConversionTarget &target)
Updates the ConversionTarget with dynamic legality of SCF operations based on the provided type conve...
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
Definition: Utils.cpp:986
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
const FrozenRewritePatternSet & patterns
void populateSCFToControlFlowConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert SCF operations to CFG branch-based operations within the Control...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
Definition: Utils.cpp:113
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
Definition: Utils.cpp:1365
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
Definition: Utils.cpp:1418
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Definition: RegionUtils.cpp:43
Options to dictate how loops should be pipelined.
Definition: Transforms.h:129