MLIR  16.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/PatternMatch.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "transform-dialect"
22 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
23 
24 using namespace mlir;
25 
29  types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
30  return success();
31 }
32 
34  ValueRange) {}
35 
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
38 
39 //===----------------------------------------------------------------------===//
40 // PatternApplicatorExtension
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 /// A simple pattern rewriter that can be constructed from a context. This is
45 /// necessary to apply patterns to a specific op locally.
46 class TrivialPatternRewriter : public PatternRewriter {
47 public:
48  explicit TrivialPatternRewriter(MLIRContext *context)
49  : PatternRewriter(context) {}
50 };
51 
52 /// A TransformState extension that keeps track of compiled PDL pattern sets.
53 /// This is intended to be used alongside the WithPDLPatterns op. The extension
54 /// can be constructed given an operation that has a SymbolTable trait and
55 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
56 /// by one when requested; this behavior is subject to change.
57 class PatternApplicatorExtension : public transform::TransformState::Extension {
58 public:
59  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
60 
61  /// Creates the extension for patterns contained in `patternContainer`.
62  explicit PatternApplicatorExtension(transform::TransformState &state,
63  Operation *patternContainer)
64  : Extension(state), patterns(patternContainer) {}
65 
66  /// Appends to `results` the operations contained in `root` that matched the
67  /// PDL pattern with the given name. Note that `root` may or may not be the
68  /// operation that contains PDL patterns. Returns failure if the pattern
69  /// cannot be found. Note that when no operations are matched, this still
70  /// succeeds as long as the pattern exists.
71  LogicalResult findAllMatches(StringRef patternName, Operation *root,
73 
74 private:
75  /// Map from the pattern name to a singleton set of rewrite patterns that only
76  /// contains the pattern with this name. Populated when the pattern is first
77  /// requested.
78  // TODO: reconsider the efficiency of this storage when more usage data is
79  // available. Storing individual patterns in a set and triggering compilation
80  // for each of them has overhead. So does compiling a large set of patterns
81  // only to apply a handful of them.
82  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
83 
84  /// A symbol table operation containing the relevant PDL patterns.
85  SymbolTable patterns;
86 };
87 
88 LogicalResult PatternApplicatorExtension::findAllMatches(
89  StringRef patternName, Operation *root,
91  auto it = compiledPatterns.find(patternName);
92  if (it == compiledPatterns.end()) {
93  auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
94  if (!patternOp)
95  return failure();
96 
97  OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
98  patternOp->moveBefore(pdlModuleOp->getBody(),
99  pdlModuleOp->getBody()->end());
100  PDLPatternModule patternModule(std::move(pdlModuleOp));
101 
102  // Merge in the hooks owned by the dialect. Make a copy as they may be
103  // also used by the following operations.
104  auto *dialect =
105  root->getContext()->getLoadedDialect<transform::TransformDialect>();
106  for (const auto &pair : dialect->getPDLConstraintHooks())
107  patternModule.registerConstraintFunction(pair.first(), pair.second);
108 
109  // Register a noop rewriter because PDL requires patterns to end with some
110  // rewrite call.
111  patternModule.registerRewriteFunction(
112  "transform.dialect", [](PatternRewriter &, Operation *) {});
113 
114  it = compiledPatterns
115  .try_emplace(patternOp.getName(), std::move(patternModule))
116  .first;
117  }
118 
119  PatternApplicator applicator(it->second);
120  TrivialPatternRewriter rewriter(root->getContext());
121  applicator.applyDefaultCostModel();
122  root->walk([&](Operation *op) {
123  if (succeeded(applicator.matchAndRewrite(op, rewriter)))
124  results.push_back(op);
125  });
126 
127  return success();
128 }
129 } // namespace
130 
131 //===----------------------------------------------------------------------===//
132 // AlternativesOp
133 //===----------------------------------------------------------------------===//
134 
136 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
137  if (index && getOperation()->getNumOperands() == 1)
138  return getOperation()->getOperands();
139  return OperandRange(getOperation()->operand_end(),
140  getOperation()->operand_end());
141 }
142 
143 void transform::AlternativesOp::getSuccessorRegions(
144  Optional<unsigned> index, ArrayRef<Attribute> operands,
146  for (Region &alternative : llvm::drop_begin(
147  getAlternatives(), index.has_value() ? *index + 1 : 0)) {
148  regions.emplace_back(&alternative, !getOperands().empty()
149  ? alternative.getArguments()
151  }
152  if (index.has_value())
153  regions.emplace_back(getOperation()->getResults());
154 }
155 
156 void transform::AlternativesOp::getRegionInvocationBounds(
158  (void)operands;
159  // The region corresponding to the first alternative is always executed, the
160  // remaining may or may not be executed.
161  bounds.reserve(getNumRegions());
162  bounds.emplace_back(1, 1);
163  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
164 }
165 
166 static void forwardTerminatorOperands(Block *block,
168  transform::TransformResults &results) {
169  for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
170  block->getParentOp()->getOpResults())) {
171  Value terminatorOperand = std::get<0>(pair);
172  OpResult result = std::get<1>(pair);
173  results.set(result, state.getPayloadOps(terminatorOperand));
174  }
175 }
176 
178 transform::AlternativesOp::apply(transform::TransformResults &results,
179  transform::TransformState &state) {
180  SmallVector<Operation *> originals;
181  if (Value scopeHandle = getScope())
182  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
183  else
184  originals.push_back(state.getTopLevel());
185 
186  for (Operation *original : originals) {
187  if (original->isAncestor(getOperation())) {
189  emitError() << "scope must not contain the transforms being applied";
190  diag.attachNote(original->getLoc()) << "scope";
192  }
193  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
195  emitError()
196  << "only isolated-from-above ops can be alternative scopes";
197  diag.attachNote(original->getLoc()) << "scope";
198  return DiagnosedSilenceableFailure(std::move(diag));
199  }
200  }
201 
202  for (Region &reg : getAlternatives()) {
203  // Clone the scope operations and make the transforms in this alternative
204  // region apply to them by virtue of mapping the block argument (the only
205  // visible handle) to the cloned scope operations. This effectively prevents
206  // the transformation from accessing any IR outside the scope.
207  auto scope = state.make_region_scope(reg);
208  auto clones = llvm::to_vector(
209  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
210  auto deleteClones = llvm::make_scope_exit([&] {
211  for (Operation *clone : clones)
212  clone->erase();
213  });
214  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
216 
217  bool failed = false;
218  for (Operation &transform : reg.front().without_terminator()) {
220  state.applyTransform(cast<TransformOpInterface>(transform));
221  if (result.isSilenceableFailure()) {
222  LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
223  << "\n");
224  failed = true;
225  break;
226  }
227 
228  if (::mlir::failed(result.silence()))
230  }
231 
232  // If all operations in the given alternative succeeded, no need to consider
233  // the rest. Replace the original scoping operation with the clone on which
234  // the transformations were performed.
235  if (!failed) {
236  // We will be using the clones, so cancel their scheduled deletion.
237  deleteClones.release();
238  IRRewriter rewriter(getContext());
239  for (const auto &kvp : llvm::zip(originals, clones)) {
240  Operation *original = std::get<0>(kvp);
241  Operation *clone = std::get<1>(kvp);
242  original->getBlock()->getOperations().insert(original->getIterator(),
243  clone);
244  rewriter.replaceOp(original, clone->getResults());
245  }
246  forwardTerminatorOperands(&reg.front(), state, results);
248  }
249  }
250  return emitSilenceableError() << "all alternatives failed";
251 }
252 
254  for (Region &alternative : getAlternatives()) {
255  Block &block = alternative.front();
256  if (block.getNumArguments() != 1 ||
257  !block.getArgument(0).getType().isa<pdl::OperationType>()) {
258  return emitOpError()
259  << "expects region blocks to have one operand of type "
260  << pdl::OperationType::get(getContext());
261  }
262 
263  Operation *terminator = block.getTerminator();
264  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
265  InFlightDiagnostic diag = emitOpError()
266  << "expects terminator operands to have the "
267  "same type as results of the operation";
268  diag.attachNote(terminator->getLoc()) << "terminator";
269  return diag;
270  }
271  }
272 
273  return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // ForeachOp
278 //===----------------------------------------------------------------------===//
279 
281 transform::ForeachOp::apply(transform::TransformResults &results,
282  transform::TransformState &state) {
283  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
284  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
285 
286  for (Operation *op : payloadOps) {
287  auto scope = state.make_region_scope(getBody());
288  if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
290 
291  // Execute loop body.
292  for (Operation &transform : getBody().front().without_terminator()) {
294  cast<transform::TransformOpInterface>(transform));
295  if (!result.succeeded())
296  return result;
297  }
298 
299  // Append yielded payload ops to result list (if any).
300  for (unsigned i = 0; i < getNumResults(); ++i) {
301  ArrayRef<Operation *> yieldedOps =
302  state.getPayloadOps(getYieldOp().getOperand(i));
303  resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
304  }
305  }
306 
307  for (unsigned i = 0; i < getNumResults(); ++i)
308  results.set(getResult(i).cast<OpResult>(), resultOps[i]);
309 
311 }
312 
313 void transform::ForeachOp::getEffects(
315  BlockArgument iterVar = getIterationVariable();
316  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
317  return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
318  })) {
319  consumesHandle(getTarget(), effects);
320  } else {
321  onlyReadsHandle(getTarget(), effects);
322  }
323 
324  for (Value result : getResults())
325  producesHandle(result, effects);
326 }
327 
328 void transform::ForeachOp::getSuccessorRegions(
329  Optional<unsigned> index, ArrayRef<Attribute> operands,
331  Region *bodyRegion = &getBody();
332  if (!index) {
333  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
334  return;
335  }
336 
337  // Branch back to the region or the parent.
338  assert(*index == 0 && "unexpected region index");
339  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
340  regions.emplace_back();
341 }
342 
344 transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
345  // The iteration variable op handle is mapped to a subset (one op to be
346  // precise) of the payload ops of the ForeachOp operand.
347  assert(index && *index == 0 && "unexpected region index");
348  return getOperation()->getOperands();
349 }
350 
351 transform::YieldOp transform::ForeachOp::getYieldOp() {
352  return cast<transform::YieldOp>(getBody().front().getTerminator());
353 }
354 
356  auto yieldOp = getYieldOp();
357  if (getNumResults() != yieldOp.getNumOperands())
358  return emitOpError() << "expects the same number of results as the "
359  "terminator has operands";
360  for (Value v : yieldOp.getOperands())
361  if (!v.getType().isa<pdl::OperationType>())
362  return yieldOp->emitOpError("expects only PDL_Operation operands");
363  return success();
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // GetClosestIsolatedParentOp
368 //===----------------------------------------------------------------------===//
369 
370 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
372  SetVector<Operation *> parents;
373  for (Operation *target : state.getPayloadOps(getTarget())) {
374  Operation *parent =
376  if (!parent) {
378  emitSilenceableError()
379  << "could not find an isolated-from-above parent op";
380  diag.attachNote(target->getLoc()) << "target op";
381  return diag;
382  }
383  parents.insert(parent);
384  }
385  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // MergeHandlesOp
391 //===----------------------------------------------------------------------===//
392 
394 transform::MergeHandlesOp::apply(transform::TransformResults &results,
395  transform::TransformState &state) {
396  SmallVector<Operation *> operations;
397  for (Value operand : getHandles())
398  llvm::append_range(operations, state.getPayloadOps(operand));
399  if (!getDeduplicate()) {
400  results.set(getResult().cast<OpResult>(), operations);
402  }
403 
404  SetVector<Operation *> uniqued(operations.begin(), operations.end());
405  results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
407 }
408 
409 void transform::MergeHandlesOp::getEffects(
411  consumesHandle(getHandles(), effects);
412  producesHandle(getResult(), effects);
413 
414  // There are no effects on the Payload IR as this is only a handle
415  // manipulation.
416 }
417 
418 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
419  if (getDeduplicate() || getHandles().size() != 1)
420  return {};
421 
422  // If deduplication is not required and there is only one operand, it can be
423  // used directly instead of merging.
424  return getHandles().front();
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // PDLMatchOp
429 //===----------------------------------------------------------------------===//
430 
432 transform::PDLMatchOp::apply(transform::TransformResults &results,
433  transform::TransformState &state) {
434  auto *extension = state.getExtension<PatternApplicatorExtension>();
435  assert(extension &&
436  "expected PatternApplicatorExtension to be attached by the parent op");
437  SmallVector<Operation *> targets;
438  for (Operation *root : state.getPayloadOps(getRoot())) {
439  if (failed(extension->findAllMatches(
440  getPatternName().getLeafReference().getValue(), root, targets))) {
441  emitOpError() << "could not find pattern '" << getPatternName() << "'";
443  }
444  }
445  results.set(getResult().cast<OpResult>(), targets);
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // ReplicateOp
451 //===----------------------------------------------------------------------===//
452 
454 transform::ReplicateOp::apply(transform::TransformResults &results,
455  transform::TransformState &state) {
456  unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
457  for (const auto &en : llvm::enumerate(getHandles())) {
458  Value handle = en.value();
459  ArrayRef<Operation *> current = state.getPayloadOps(handle);
460  SmallVector<Operation *> payload;
461  payload.reserve(numRepetitions * current.size());
462  for (unsigned i = 0; i < numRepetitions; ++i)
463  llvm::append_range(payload, current);
464  results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
465  }
467 }
468 
469 void transform::ReplicateOp::getEffects(
471  onlyReadsHandle(getPattern(), effects);
472  consumesHandle(getHandles(), effects);
473  producesHandle(getReplicated(), effects);
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // SequenceOp
478 //===----------------------------------------------------------------------===//
479 
481 transform::SequenceOp::apply(transform::TransformResults &results,
482  transform::TransformState &state) {
483  // Map the entry block argument to the list of operations.
484  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
485  if (failed(mapBlockArguments(state)))
487 
488  // Apply the sequenced ops one by one.
489  for (Operation &transform : getBodyBlock()->without_terminator()) {
491  state.applyTransform(cast<TransformOpInterface>(transform));
492  if (result.isDefiniteFailure())
493  return result;
494 
495  if (result.isSilenceableFailure()) {
496  if (getFailurePropagationMode() == FailurePropagationMode::Propagate)
497  return result;
498  (void)result.silence();
499  }
500  }
501 
502  // Forward the operation mapping for values yielded from the sequence to the
503  // values produced by the sequence op.
504  forwardTerminatorOperands(getBodyBlock(), state, results);
506 }
507 
508 /// Returns `true` if the given op operand may be consuming the handle value in
509 /// the Transform IR. That is, if it may have a Free effect on it.
511  // Conservatively assume the effect being present in absence of the interface.
512  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
513  if (!iface)
514  return true;
515 
516  return isHandleConsumed(use.get(), iface);
517 }
518 
521  function_ref<InFlightDiagnostic()> reportError) {
522  OpOperand *potentialConsumer = nullptr;
523  for (OpOperand &use : value.getUses()) {
524  if (!isValueUsePotentialConsumer(use))
525  continue;
526 
527  if (!potentialConsumer) {
528  potentialConsumer = &use;
529  continue;
530  }
531 
532  InFlightDiagnostic diag = reportError()
533  << " has more than one potential consumer";
534  diag.attachNote(potentialConsumer->getOwner()->getLoc())
535  << "used here as operand #" << potentialConsumer->getOperandNumber();
536  diag.attachNote(use.getOwner()->getLoc())
537  << "used here as operand #" << use.getOperandNumber();
538  return diag;
539  }
540 
541  return success();
542 }
543 
545  // Check if the block argument has more than one consuming use.
546  for (BlockArgument argument : getBodyBlock()->getArguments()) {
547  auto report = [&]() {
548  return (emitOpError() << "block argument #" << argument.getArgNumber());
549  };
550  if (failed(checkDoubleConsume(argument, report)))
551  return failure();
552  }
553 
554  // Check properties of the nested operations they cannot check themselves.
555  for (Operation &child : *getBodyBlock()) {
556  if (!isa<TransformOpInterface>(child) &&
557  &child != &getBodyBlock()->back()) {
559  emitOpError()
560  << "expected children ops to implement TransformOpInterface";
561  diag.attachNote(child.getLoc()) << "op without interface";
562  return diag;
563  }
564 
565  for (OpResult result : child.getResults()) {
566  auto report = [&]() {
567  return (child.emitError() << "result #" << result.getResultNumber());
568  };
569  if (failed(checkDoubleConsume(result, report)))
570  return failure();
571  }
572  }
573 
574  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
575  getOperation()->getResultTypes()) {
576  InFlightDiagnostic diag = emitOpError()
577  << "expects the types of the terminator operands "
578  "to match the types of the result";
579  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
580  return diag;
581  }
582  return success();
583 }
584 
585 void transform::SequenceOp::getEffects(
587  auto *mappingResource = TransformMappingResource::get();
588  effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
589 
590  for (Value result : getResults()) {
591  effects.emplace_back(MemoryEffects::Allocate::get(), result,
592  mappingResource);
593  effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
594  }
595 
596  if (!getRoot()) {
597  for (Operation &op : *getBodyBlock()) {
598  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
599  if (!iface) {
600  // TODO: fill all possible effects; or require ops to actually implement
601  // the memory effect interface always
602  assert(false);
603  }
604 
606  iface.getEffects(effects);
607  }
608  return;
609  }
610 
611  // Carry over all effects on the argument of the entry block as those on the
612  // operand, this is the same value just remapped.
613  for (Operation &op : *getBodyBlock()) {
614  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
615  if (!iface) {
616  // TODO: fill all possible effects; or require ops to actually implement
617  // the memory effect interface always
618  assert(false);
619  }
620 
622  iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
623  for (const auto &effect : nestedEffects)
624  effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
625  }
626 }
627 
629 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
630  assert(index && *index == 0 && "unexpected region index");
631  if (getOperation()->getNumOperands() == 1)
632  return getOperation()->getOperands();
633  return OperandRange(getOperation()->operand_end(),
634  getOperation()->operand_end());
635 }
636 
637 void transform::SequenceOp::getSuccessorRegions(
638  Optional<unsigned> index, ArrayRef<Attribute> operands,
640  if (!index) {
641  Region *bodyRegion = &getBody();
642  regions.emplace_back(bodyRegion, !operands.empty()
643  ? bodyRegion->getArguments()
645  return;
646  }
647 
648  assert(*index == 0 && "unexpected region index");
649  regions.emplace_back(getOperation()->getResults());
650 }
651 
652 void transform::SequenceOp::getRegionInvocationBounds(
654  (void)operands;
655  bounds.emplace_back(1, 1);
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // WithPDLPatternsOp
660 //===----------------------------------------------------------------------===//
661 
663 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
664  transform::TransformState &state) {
665  OwningOpRef<ModuleOp> pdlModuleOp =
666  ModuleOp::create(getOperation()->getLoc());
667  TransformOpInterface transformOp = nullptr;
668  for (Operation &nested : getBody().front()) {
669  if (!isa<pdl::PatternOp>(nested)) {
670  transformOp = cast<TransformOpInterface>(nested);
671  break;
672  }
673  }
674 
675  state.addExtension<PatternApplicatorExtension>(getOperation());
676  auto guard = llvm::make_scope_exit(
677  [&]() { state.removeExtension<PatternApplicatorExtension>(); });
678 
679  auto scope = state.make_region_scope(getBody());
680  if (failed(mapBlockArguments(state)))
682  return state.applyTransform(transformOp);
683 }
684 
686  Block *body = getBodyBlock();
687  Operation *topLevelOp = nullptr;
688  for (Operation &op : body->getOperations()) {
689  if (isa<pdl::PatternOp>(op))
690  continue;
691 
693  if (topLevelOp) {
695  emitOpError() << "expects only one non-pattern op in its body";
696  diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
697  diag.attachNote(op.getLoc()) << "second non-pattern op";
698  return diag;
699  }
700  topLevelOp = &op;
701  continue;
702  }
703 
705  emitOpError()
706  << "expects only pattern and top-level transform ops in its body";
707  diag.attachNote(op.getLoc()) << "offending op";
708  return diag;
709  }
710 
711  if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
712  InFlightDiagnostic diag = emitOpError() << "cannot be nested";
713  diag.attachNote(parent.getLoc()) << "parent operation";
714  return diag;
715  }
716 
717  return success();
718 }
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
Diagnostic & attachNote(Optional< Location > loc=llvm::None)
Attaches a note to the last diagnostic.
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
void set(OpResult value, ArrayRef< Operation *> ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
static std::string diag(llvm::Value &v)
The result of a transform IR operation application.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This is a value defined by a result of an operation.
Definition: Value.h:425
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:310
Block represents an ordered list of Operations.
Definition: Block.h:29
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
This class represents a single result from folding an operation.
Definition: OpDefinition.h:239
static ParseResult parsePDLOpTypedResults(OpAsmParser &parser, SmallVectorImpl< Type > &types, const SmallVectorImpl< OpAsmParser::UnresolvedOperand > &handles)
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
OpListType & getOperations()
Definition: Block.h:128
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:179
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:74
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
void removeExtension()
Removes the extension of the specified type.
Operation & front()
Definition: Block.h:144
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
BlockArgListType getArguments()
Definition: Region.h:81
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
BlockArgument getArgument(unsigned i)
Definition: Block.h:120
static constexpr const bool value
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic, preserves the other states.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::enable_if< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT >::type walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one)...
Definition: Operation.h:574
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value: ...
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Operation * clone(BlockAndValueMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:554
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Diagnostic & attachNote(Optional< Location > noteLoc=llvm::None)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:348
unsigned getNumArguments()
Definition: Block.h:119
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
#define DBGS()
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
The state maintained across applications of various ops implementing the TransformOpInterface.
static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange, ValueRange)
bool succeeded() const
Returns true if this is a success.
This class represents an argument of a Block.
Definition: Value.h:300
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
result_range getOpResults()
Definition: Operation.h:337
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:584
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation *> operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:230
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
Ty * getExtension()
Returns the extension of the specified type.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
Type getType() const
Return the type of this value.
Definition: Value.h:118
This class provides the API for ops that are known to be isolated from above.
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
This class manages the application of a group of rewrite patterns, with a user-provided cost model...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Base class for TransformState extensions that allow TransformState to contain user-specified informat...
This class represents an operand of an operation.
Definition: Value.h:251
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:40
type_range getTypes() const
Definition: ValueRange.cpp:26
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
bool isa() const
Definition: Types.h:254
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
This class represents success/failure for parsing-like operations that find it important to chain tog...
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
result_range getResults()
Definition: Operation.h:332
static void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:197
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly...
std::string getMessage() const
Returns the diagnostic message without emitting it.