MLIR  16.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/PatternMatch.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/ScopeExit.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "transform-dialect"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
25 
26 using namespace mlir;
27 
28 #define GET_OP_CLASSES
29 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
30 
31 //===----------------------------------------------------------------------===//
32 // PatternApplicatorExtension
33 //===----------------------------------------------------------------------===//
34 
35 namespace {
36 /// A simple pattern rewriter that can be constructed from a context. This is
37 /// necessary to apply patterns to a specific op locally.
38 class TrivialPatternRewriter : public PatternRewriter {
39 public:
40  explicit TrivialPatternRewriter(MLIRContext *context)
41  : PatternRewriter(context) {}
42 };
43 
44 /// A TransformState extension that keeps track of compiled PDL pattern sets.
45 /// This is intended to be used alongside the WithPDLPatterns op. The extension
46 /// can be constructed given an operation that has a SymbolTable trait and
47 /// contains pdl::PatternOp instances. The patterns are compiled lazily and one
48 /// by one when requested; this behavior is subject to change.
49 class PatternApplicatorExtension : public transform::TransformState::Extension {
50 public:
51  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
52 
53  /// Creates the extension for patterns contained in `patternContainer`.
54  explicit PatternApplicatorExtension(transform::TransformState &state,
55  Operation *patternContainer)
56  : Extension(state), patterns(patternContainer) {}
57 
58  /// Appends to `results` the operations contained in `root` that matched the
59  /// PDL pattern with the given name. Note that `root` may or may not be the
60  /// operation that contains PDL patterns. Returns failure, without emitting a
61  /// diagnostic, if the pattern cannot be found. When no operations are matched, this still
62  /// succeeds as long as the pattern exists.
63  LogicalResult findAllMatches(StringRef patternName, Operation *root,
// NOTE(review): listing line 64 is absent from this rendering; it presumably
// declares the trailing `results` output parameter used by the definition.
65 
66 private:
67  /// Map from the pattern name to a singleton set of rewrite patterns that only
68  /// contains the pattern with this name. Populated when the pattern is first
69  /// requested.
70  // TODO: reconsider the efficiency of this storage when more usage data is
71  // available. Storing individual patterns in a set and triggering compilation
72  // for each of them has overhead. So does compiling a large set of patterns
73  // only to apply a handful of them.
74  llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
75 
76  /// A symbol table operation containing the relevant PDL patterns.
77  SymbolTable patterns;
78 };
79 
/// Looks up the PDL pattern named `patternName`, compiling and caching it in
/// `compiledPatterns` on first request. It then walks `root` (including `root`
/// itself) and appends every op the pattern matches to `results`.
/// (Listing omits line 82, the remainder of the parameter list.)
80 LogicalResult PatternApplicatorExtension::findAllMatches(
81  StringRef patternName, Operation *root,
83  auto it = compiledPatterns.find(patternName);
84  if (it == compiledPatterns.end()) {
85  auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
  // Unknown pattern: return failure without a diagnostic; the caller reports.
86  if (!patternOp)
87  return failure();
88 
  // NOTE(review): this moves the pattern op out of its symbol-table container
  // (mutating the transform IR) into a temporary module that is then owned by
  // the PDLPatternModule. Later requests are served from `compiledPatterns`.
89  OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
90  patternOp->moveBefore(pdlModuleOp->getBody(),
91  pdlModuleOp->getBody()->end());
92  PDLPatternModule patternModule(std::move(pdlModuleOp));
93 
94  // Merge in the hooks owned by the dialect. Make a copy as they may also be
95  // used by subsequent operations.
96  auto *dialect =
97  root->getContext()->getLoadedDialect<transform::TransformDialect>();
98  for (const auto &pair : dialect->getPDLConstraintHooks())
99  patternModule.registerConstraintFunction(pair.first(), pair.second);
100 
101  // Register a noop rewriter because PDL requires patterns to end with some
102  // rewrite call.
103  patternModule.registerRewriteFunction(
104  "transform.dialect", [](PatternRewriter &, Operation *) {});
105 
106  it = compiledPatterns
107  .try_emplace(patternOp.getName(), std::move(patternModule))
108  .first;
109  }
110 
111  PatternApplicator applicator(it->second);
112  TrivialPatternRewriter rewriter(root->getContext());
113  applicator.applyDefaultCostModel();
  // Only the no-op "transform.dialect" rewrite function is registered above,
  // so a successful matchAndRewrite is treated purely as "op matched".
114  root->walk([&](Operation *op) {
115  if (succeeded(applicator.matchAndRewrite(op, rewriter)))
116  results.push_back(op);
117  });
118 
119  return success();
120 }
121 } // namespace
122 
123 //===----------------------------------------------------------------------===//
124 // AlternativesOp
125 //===----------------------------------------------------------------------===//
126 
/// Entry operands for an alternative region: when `index` names a region and
/// the op has exactly one operand, that operand is forwarded; otherwise an
/// empty range is returned. (Listing omits line 127, the return type.)
128 transform::AlternativesOp::getSuccessorEntryOperands(Optional<unsigned> index) {
129  if (index && getOperation()->getNumOperands() == 1)
130  return getOperation()->getOperands();
131  return OperandRange(getOperation()->operand_end(),
132  getOperation()->operand_end());
133 }
134 
/// Control-flow successors. From the parent op, any alternative may be
/// entered. From alternative #index, control may move to any later alternative
/// or return to the op's results. When the op has operands, an alternative's
/// block arguments are used as its successor inputs (the other branch, line
/// 142, is not visible). (Listing omits lines 137 and 142.)
135 void transform::AlternativesOp::getSuccessorRegions(
136  Optional<unsigned> index, ArrayRef<Attribute> operands,
138  for (Region &alternative : llvm::drop_begin(
139  getAlternatives(), index.has_value() ? *index + 1 : 0)) {
140  regions.emplace_back(&alternative, !getOperands().empty()
141  ? alternative.getArguments()
143  }
144  if (index.has_value())
145  regions.emplace_back(getOperation()->getResults());
146 }
147 
/// Invocation bounds: the first alternative runs exactly once and every other
/// alternative runs at most once. (Listing omits line 149, the parameters.)
148 void transform::AlternativesOp::getRegionInvocationBounds(
150  (void)operands;
151  // The region corresponding to the first alternative is always executed, the
152  // remaining may or may not be executed.
153  bounds.reserve(getNumRegions());
154  bounds.emplace_back(1, 1);
155  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
156 }
157 
/// forwardEmptyOperands: associates every result of `block`'s parent op with an
/// empty list of payload ops. (Listing omits line 158, the start of the
/// signature.)
159  transform::TransformResults &results) {
160  for (const auto &res : block->getParentOp()->getOpResults())
161  results.set(res, {});
162 }
163 
/// Associates each result of `block`'s parent op with the payload ops mapped to
/// the terminator operand at the same position. Pairs are formed with
/// llvm::zip, which stops at the shorter of the two ranges.
/// (Listing omits line 165, the `state` parameter.)
164 static void forwardTerminatorOperands(Block *block,
166  transform::TransformResults &results) {
167  for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(),
168  block->getParentOp()->getOpResults())) {
169  Value terminatorOperand = std::get<0>(pair);
170  OpResult result = std::get<1>(pair);
171  results.set(result, state.getPayloadOps(terminatorOperand));
172  }
173 }
174 
/// Tries each alternative region in order on fresh clones of the scope ops.
/// The scope ops are the payload of the `scope` operand, or the top-level op
/// when there is no operand. For the first alternative whose transforms all
/// succeed, the clones replace the originals and its terminator operands are
/// forwarded as the op's results. Clones made for failed alternatives are
/// erased on scope exit.
/// (Listing omits lines 175, 212, 216, 226 and 244.)
176 transform::AlternativesOp::apply(transform::TransformResults &results,
177  transform::TransformState &state) {
178  SmallVector<Operation *> originals;
179  if (Value scopeHandle = getScope())
180  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
181  else
182  originals.push_back(state.getTopLevel());
183 
184  for (Operation *original : originals) {
185  if (original->isAncestor(getOperation())) {
186  auto diag = emitDefiniteFailure()
187  << "scope must not contain the transforms being applied";
188  diag.attachNote(original->getLoc()) << "scope";
189  return diag;
190  }
191  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
192  auto diag = emitDefiniteFailure()
193  << "only isolated-from-above ops can be alternative scopes";
194  diag.attachNote(original->getLoc()) << "scope";
195  return diag;
196  }
197  }
198 
199  for (Region &reg : getAlternatives()) {
200  // Clone the scope operations and make the transforms in this alternative
201  // region apply to them by virtue of mapping the block argument (the only
202  // visible handle) to the cloned scope operations. This effectively prevents
203  // the transformation from accessing any IR outside the scope.
204  auto scope = state.make_region_scope(reg);
205  auto clones = llvm::to_vector(
206  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
207  auto deleteClones = llvm::make_scope_exit([&] {
208  for (Operation *clone : clones)
209  clone->erase();
210  });
211  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
213 
  // This local shadows the free function mlir::failed, which is why the call
  // below is spelled `::mlir::failed`.
214  bool failed = false;
215  for (Operation &transform : reg.front().without_terminator()) {
217  state.applyTransform(cast<TransformOpInterface>(transform));
218  if (result.isSilenceableFailure()) {
219  LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
220  << "\n");
221  failed = true;
222  break;
223  }
224 
225  if (::mlir::failed(result.silence()))
227  }
228 
229  // If all operations in the given alternative succeeded, no need to consider
230  // the rest. Replace the original scoping operation with the clone on which
231  // the transformations were performed.
232  if (!failed) {
233  // We will be using the clones, so cancel their scheduled deletion.
234  deleteClones.release();
235  IRRewriter rewriter(getContext());
236  for (const auto &kvp : llvm::zip(originals, clones)) {
237  Operation *original = std::get<0>(kvp);
238  Operation *clone = std::get<1>(kvp);
      // Splice the clone in right before the original, then redirect the
      // original's uses to the clone's results via the rewriter.
239  original->getBlock()->getOperations().insert(original->getIterator(),
240  clone);
241  rewriter.replaceOp(original, clone->getResults());
242  }
243  forwardTerminatorOperands(&reg.front(), state, results);
245  }
246  }
  // NOTE(review): on this path the op's results are not set. SequenceOp::apply,
  // by contrast, forwards empty results before an early exit. Confirm this is
  // intended.
247  return emitSilenceableError() << "all alternatives failed";
248 }
249 
/// AlternativesOp::verify: each alternative's terminator operand types must
/// equal the op's result types. (Listing omits line 250, the signature.)
251  for (Region &alternative : getAlternatives()) {
252  Block &block = alternative.front();
253  Operation *terminator = block.getTerminator();
254  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
255  InFlightDiagnostic diag = emitOpError()
256  << "expects terminator operands to have the "
257  "same type as results of the operation";
258  diag.attachNote(terminator->getLoc()) << "terminator";
259  return diag;
260  }
261  }
262 
263  return success();
264 }
265 
266 //===----------------------------------------------------------------------===//
267 // CastOp
268 //===----------------------------------------------------------------------===//
269 
/// Passes each `target` payload op through unchanged into `results`.
/// (Listing omits lines 270, 272 and 275: the return type, the `results`
/// parameter and the return statement.)
271 transform::CastOp::applyToOne(Operation *target,
273  transform::TransformState &state) {
274  results.push_back(target);
276 }
277 
/// Only reads the payload IR, consumes the input handle and produces the output
/// handle. (Listing omits line 279, the `effects` parameter.)
278 void transform::CastOp::getEffects(
280  onlyReadsPayload(effects);
281  consumesHandle(getInput(), effects);
282  producesHandle(getOutput(), effects);
283 }
284 
285 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
286  assert(inputs.size() == 1 && "expected one input");
287  assert(outputs.size() == 1 && "expected one output");
288  return llvm::all_of(
289  std::initializer_list<Type>{inputs.front(), outputs.front()},
290  [](Type ty) {
291  return ty.isa<pdl::OperationType, transform::TransformTypeInterface>();
292  });
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // ForeachOp
297 //===----------------------------------------------------------------------===//
298 
/// Runs the body once per payload op of the target handle, with the iteration
/// variable mapped to that single op. For each result, the payload ops yielded
/// on every iteration are concatenated in iteration order.
/// (Listing omits lines 299, 308, 312 and 329.)
300 transform::ForeachOp::apply(transform::TransformResults &results,
301  transform::TransformState &state) {
302  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
303  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
304 
305  for (Operation *op : payloadOps) {
306  auto scope = state.make_region_scope(getBody());
307  if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
309 
310  // Execute loop body.
311  for (Operation &transform : getBody().front().without_terminator()) {
313  cast<transform::TransformOpInterface>(transform));
    // NOTE(review): any non-success result is returned immediately, leaving
    // the op's results unset. Confirm callers tolerate this.
314  if (!result.succeeded())
315  return result;
316  }
317 
318  // Append yielded payload ops to result list (if any).
319  for (unsigned i = 0; i < getNumResults(); ++i) {
320  ArrayRef<Operation *> yieldedOps =
321  state.getPayloadOps(getYieldOp().getOperand(i));
322  resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
323  }
324  }
325 
326  for (unsigned i = 0; i < getNumResults(); ++i)
327  results.set(getResult(i).cast<OpResult>(), resultOps[i]);
328 
330 }
331 
/// The target handle is declared consumed if any body op consumes the
/// iteration variable, and only read otherwise. Every result handle is
/// produced. (Listing omits line 333, the `effects` parameter.)
332 void transform::ForeachOp::getEffects(
334  BlockArgument iterVar = getIterationVariable();
335  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
336  return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
337  })) {
338  consumesHandle(getTarget(), effects);
339  } else {
340  onlyReadsHandle(getTarget(), effects);
341  }
342 
343  for (Value result : getResults())
344  producesHandle(result, effects);
345 }
346 
/// From the parent op, control enters the body. From the body, it may loop
/// back into the body or return to the parent. (Listing omits line 349.)
347 void transform::ForeachOp::getSuccessorRegions(
348  Optional<unsigned> index, ArrayRef<Attribute> operands,
350  Region *bodyRegion = &getBody();
351  if (!index) {
352  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
353  return;
354  }
355 
356  // Branch back to the region or the parent.
357  assert(*index == 0 && "unexpected region index");
358  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
359  regions.emplace_back();
360 }
361 
/// Entering the (only) body region forwards all of the op's operands.
/// (Listing omits line 362, presumably the return type.)
363 transform::ForeachOp::getSuccessorEntryOperands(Optional<unsigned> index) {
364  // The iteration variable op handle is mapped to a subset (one op to be
365  // precise) of the payload ops of the ForeachOp operand.
366  assert(index && *index == 0 && "unexpected region index");
367  return getOperation()->getOperands();
368 }
369 
370 transform::YieldOp transform::ForeachOp::getYieldOp() {
371  return cast<transform::YieldOp>(getBody().front().getTerminator());
372 }
373 
/// ForeachOp::verify: the yield must have exactly one operand per op result,
/// and every yielded value's type must implement TransformTypeInterface.
/// (Listing omits line 374, the signature.)
375  auto yieldOp = getYieldOp();
376  if (getNumResults() != yieldOp.getNumOperands())
377  return emitOpError() << "expects the same number of results as the "
378  "terminator has operands";
379  for (Value v : yieldOp.getOperands())
380  if (!v.getType().isa<TransformTypeInterface>())
381  return yieldOp->emitOpError(
382  "expects operands to have types implementing TransformTypeInterface");
383  return success();
384 }
385 
386 //===----------------------------------------------------------------------===//
387 // GetClosestIsolatedParentOp
388 //===----------------------------------------------------------------------===//
389 
/// For each target payload op, collects its closest isolated-from-above parent.
/// Parents are deduplicated and kept in first-seen order (SetVector). Emits a
/// silenceable error for a target that has no such parent.
/// (Listing omits lines 391, 395, 397 and 406.)
390 DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
392  SetVector<Operation *> parents;
393  for (Operation *target : state.getPayloadOps(getTarget())) {
394  Operation *parent =
396  if (!parent) {
398  emitSilenceableError()
399  << "could not find an isolated-from-above parent op";
400  diag.attachNote(target->getLoc()) << "target op";
    // NOTE(review): unlike GetProducerOfOperand::apply, the result is not set
    // to an empty list before this early error return.
401  return diag;
402  }
403  parents.insert(parent);
404  }
405  results.set(getResult().cast<OpResult>(), parents.getArrayRef());
407 }
408 
409 //===----------------------------------------------------------------------===//
410 // GetProducerOfOperand
411 //===----------------------------------------------------------------------===//
412 
/// For each target payload op, returns the op that defines its
/// `operandNumber`-th operand. Emits a silenceable error if that operand does
/// not exist or is not defined by an op, for example when it is a block
/// argument. (Listing omits lines 413, 424, 430 and 436.)
414 transform::GetProducerOfOperand::apply(transform::TransformResults &results,
415  transform::TransformState &state) {
416  int64_t operandNumber = getOperandNumber();
417  SmallVector<Operation *> producers;
418  for (Operation *target : state.getPayloadOps(getTarget())) {
    // NOTE(review): a negative `operandNumber` is not rejected here. The
    // comparison is done in int64_t, so it is false for negative values and
    // getOperand() is then called out of range. Confirm the attribute is
    // constrained to be non-negative.
419  Operation *producer =
420  target->getNumOperands() <= operandNumber
421  ? nullptr
422  : target->getOperand(operandNumber).getDefiningOp();
423  if (!producer) {
425  emitSilenceableError()
426  << "could not find a producer for operand number: " << operandNumber
427  << " of " << *target;
428  diag.attachNote(target->getLoc()) << "target op";
429  results.set(getResult().cast<OpResult>(),
431  return diag;
432  }
433  producers.push_back(producer);
434  }
435  results.set(getResult().cast<OpResult>(), producers);
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // MergeHandlesOp
441 //===----------------------------------------------------------------------===//
442 
/// Concatenates the payload ops of all operand handles in operand order. When
/// `deduplicate` is set, only the first occurrence of each op is kept
/// (SetVector preserves insertion order). (Listing omits lines 443, 451, 456.)
444 transform::MergeHandlesOp::apply(transform::TransformResults &results,
445  transform::TransformState &state) {
446  SmallVector<Operation *> operations;
447  for (Value operand : getHandles())
448  llvm::append_range(operations, state.getPayloadOps(operand));
449  if (!getDeduplicate()) {
450  results.set(getResult().cast<OpResult>(), operations);
452  }
453 
454  SetVector<Operation *> uniqued(operations.begin(), operations.end());
455  results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
457 }
458 
/// Consumes every operand handle and produces the merged result handle.
/// (Listing omits line 460, the `effects` parameter.)
459 void transform::MergeHandlesOp::getEffects(
461  consumesHandle(getHandles(), effects);
462  producesHandle(getResult(), effects);
463 
464  // There are no effects on the Payload IR as this is only a handle
465  // manipulation.
466 }
467 
468 OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
469  if (getDeduplicate() || getHandles().size() != 1)
470  return {};
471 
472  // If deduplication is not required and there is only one operand, it can be
473  // used directly instead of merging.
474  return getHandles().front();
475 }
476 
477 //===----------------------------------------------------------------------===//
478 // SplitHandlesOp
479 //===----------------------------------------------------------------------===//
480 
481 void transform::SplitHandlesOp::build(OpBuilder &builder,
482  OperationState &result, Value target,
483  int64_t numResultHandles) {
484  result.addOperands(target);
485  result.addAttribute(SplitHandlesOp::getNumResultHandlesAttrName(result.name),
486  builder.getI64IntegerAttr(numResultHandles));
487  auto pdlOpType = pdl::OperationType::get(builder.getContext());
488  result.addTypes(SmallVector<pdl::OperationType>(numResultHandles, pdlOpType));
489 }
490 
/// Maps the i-th payload op of the input handle to the i-th result handle. If
/// the number of payload ops differs from `num_result_handles`, all results
/// are first set to empty lists. An empty input then yields only those empty
/// results; a non-empty input additionally emits a silenceable error.
/// (Listing omits lines 491, 505 and 516.)
492 transform::SplitHandlesOp::apply(transform::TransformResults &results,
493  transform::TransformState &state) {
  // Despite its name, `numResultHandles` is the number of payload ops held by
  // the input handle.
494  int64_t numResultHandles =
495  getHandle() ? state.getPayloadOps(getHandle()).size() : 0;
496  int64_t expectedNumResultHandles = getNumResultHandles();
497  if (numResultHandles != expectedNumResultHandles) {
498  // Failing case needs to propagate gracefully for both suppress and
499  // propagate modes.
500  for (int64_t idx = 0; idx < expectedNumResultHandles; ++idx)
501  results.set(getResults()[idx].cast<OpResult>(), {});
502  // Empty input handle corner case: always propagates empty handles in both
503  // suppress and propagate modes.
504  if (numResultHandles == 0)
506  // If the input handle was not empty and the number of result handles does
507  // not match, this is a legit silenceable error.
    // NOTE(review): the message says "only contains" even when the handle holds
    // MORE payload ops than expected, and it calls payload ops "handles".
    // Consider rewording to e.g. "but it contains N payload ops".
508  return emitSilenceableError()
509  << getHandle() << " expected to contain " << expectedNumResultHandles
510  << " operation handles but it only contains " << numResultHandles
511  << " handles";
512  }
513  // Normal successful case.
514  for (const auto &en : llvm::enumerate(state.getPayloadOps(getHandle())))
515  results.set(getResults()[en.index()].cast<OpResult>(), en.value());
517 }
518 
/// Consumes the input handle and produces every result handle.
/// (Listing omits line 520, the `effects` parameter.)
519 void transform::SplitHandlesOp::getEffects(
521  consumesHandle(getHandle(), effects);
522  producesHandle(getResults(), effects);
523  // There are no effects on the Payload IR as this is only a handle
524  // manipulation.
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // PDLMatchOp
529 //===----------------------------------------------------------------------===//
530 
/// For each payload op of `root`, collects every op nested under it (the root
/// itself included) that matches the PDL pattern named by getPatternName().
/// If the pattern does not exist, a "could not find pattern" diagnostic is
/// emitted. Relies on the PatternApplicatorExtension that
/// WithPDLPatternsOp::apply attaches to the state.
/// (Listing omits lines 531, 541 and 546.)
532 transform::PDLMatchOp::apply(transform::TransformResults &results,
533  transform::TransformState &state) {
534  auto *extension = state.getExtension<PatternApplicatorExtension>();
535  assert(extension &&
536  "expected PatternApplicatorExtension to be attached by the parent op");
537  SmallVector<Operation *> targets;
538  for (Operation *root : state.getPayloadOps(getRoot())) {
539  if (failed(extension->findAllMatches(
540  getPatternName().getLeafReference().getValue(), root, targets))) {
542  << "could not find pattern '" << getPatternName() << "'";
543  }
544  }
545  results.set(getResult().cast<OpResult>(), targets);
547 }
548 
/// Reads the root handle, produces the matched handle and only reads the
/// payload IR. (Listing omits line 550, the `effects` parameter.)
549 void transform::PDLMatchOp::getEffects(
551  onlyReadsHandle(getRoot(), effects);
552  producesHandle(getMatched(), effects);
553  onlyReadsPayload(effects);
554 }
555 
556 //===----------------------------------------------------------------------===//
557 // ReplicateOp
558 //===----------------------------------------------------------------------===//
559 
/// For each operand handle, produces a result handle whose payload is that
/// handle's payload ops repeated N times, where N is the number of payload ops
/// associated with `pattern`. (Listing omits lines 560 and 573.)
561 transform::ReplicateOp::apply(transform::TransformResults &results,
562  transform::TransformState &state) {
563  unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
564  for (const auto &en : llvm::enumerate(getHandles())) {
565  Value handle = en.value();
566  ArrayRef<Operation *> current = state.getPayloadOps(handle);
567  SmallVector<Operation *> payload;
568  payload.reserve(numRepetitions * current.size());
569  for (unsigned i = 0; i < numRepetitions; ++i)
570  llvm::append_range(payload, current);
571  results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
572  }
574 }
575 
/// Reads the `pattern` handle, consumes the handles being replicated and
/// produces the replicated handles. (Listing omits line 577.)
576 void transform::ReplicateOp::getEffects(
578  onlyReadsHandle(getPattern(), effects);
579  consumesHandle(getHandles(), effects);
580  producesHandle(getReplicated(), effects);
581 }
582 
583 //===----------------------------------------------------------------------===//
584 // SequenceOp
585 //===----------------------------------------------------------------------===//
586 
/// Maps the entry block argument, then applies the nested transforms in order.
/// A definite failure is returned immediately. A silenceable failure is
/// handled according to the failure propagation mode:
///   - Propagate: empty results are forwarded and the failure is returned.
///   - Otherwise: the failure is silenced and execution continues.
/// On completion, the terminator operands are forwarded as the op's results.
/// (Listing omits lines 587, 593, 597 and 615.)
588 transform::SequenceOp::apply(transform::TransformResults &results,
589  transform::TransformState &state) {
590  // Map the entry block argument to the list of operations.
591  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
592  if (failed(mapBlockArguments(state)))
594 
595  // Apply the sequenced ops one by one.
596  for (Operation &transform : getBodyBlock()->without_terminator()) {
598  state.applyTransform(cast<TransformOpInterface>(transform));
599  if (result.isDefiniteFailure())
600  return result;
601 
602  if (result.isSilenceableFailure()) {
603  if (getFailurePropagationMode() == FailurePropagationMode::Propagate) {
604  // Propagate empty results in case of early exit.
605  forwardEmptyOperands(getBodyBlock(), state, results);
606  return result;
607  }
608  (void)result.silence();
609  }
610  }
611 
612  // Forward the operation mapping for values yielded from the sequence to the
613  // values produced by the sequence op.
614  forwardTerminatorOperands(getBodyBlock(), state, results);
616 }
617 
618 /// Returns `true` if the given op operand may be consuming the handle value in
619 /// the Transform IR. That is, if it may have a Free effect on it.
/// (Listing omits line 620, the signature.)
621  // Conservatively assume the effect is present in the absence of the interface.
622  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
623  if (!iface)
624  return true;
625 
626  return isHandleConsumed(use.get(), iface);
627 }
628 
/// checkDoubleConsume: succeeds unless `value` has more than one potentially
/// consuming use. As soon as a second such use is found, it fails with the
/// diagnostic produced by `reportError`, adding a note on each of the two
/// uses. (Listing omits lines 629-630, the start of the signature.)
631  function_ref<InFlightDiagnostic()> reportError) {
632  OpOperand *potentialConsumer = nullptr;
633  for (OpOperand &use : value.getUses()) {
634  if (!isValueUsePotentialConsumer(use))
635  continue;
636 
637  if (!potentialConsumer) {
638  potentialConsumer = &use;
639  continue;
640  }
641 
642  InFlightDiagnostic diag = reportError()
643  << " has more than one potential consumer";
644  diag.attachNote(potentialConsumer->getOwner()->getLoc())
645  << "used here as operand #" << potentialConsumer->getOperandNumber();
646  diag.attachNote(use.getOwner()->getLoc())
647  << "used here as operand #" << use.getOperandNumber();
648  return diag;
649  }
650 
651  return success();
652 }
653 
/// SequenceOp::verify checks that:
///   - no block argument and no nested op result has more than one potential
///     consumer;
///   - every child op except the last implements TransformOpInterface;
///   - the terminator operand types match the op's result types.
/// (Listing omits lines 654 and 668.)
655  // Check if the block argument has more than one consuming use.
656  for (BlockArgument argument : getBodyBlock()->getArguments()) {
657  auto report = [&]() {
658  return (emitOpError() << "block argument #" << argument.getArgNumber());
659  };
660  if (failed(checkDoubleConsume(argument, report)))
661  return failure();
662  }
663 
664  // Check properties of the nested operations that they cannot check themselves.
665  for (Operation &child : *getBodyBlock()) {
666  if (!isa<TransformOpInterface>(child) &&
667  &child != &getBodyBlock()->back()) {
669  emitOpError()
670  << "expected children ops to implement TransformOpInterface";
671  diag.attachNote(child.getLoc()) << "op without interface";
672  return diag;
673  }
674 
675  for (OpResult result : child.getResults()) {
676  auto report = [&]() {
677  return (child.emitError() << "result #" << result.getResultNumber());
678  };
679  if (failed(checkDoubleConsume(result, report)))
680  return failure();
681  }
682  }
683 
684  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
685  getOperation()->getResultTypes()) {
686  InFlightDiagnostic diag = emitOpError()
687  << "expects the types of the terminator operands "
688  "to match the types of the result";
689  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
690  return diag;
691  }
692  return success();
693 }
694 
/// Records a Read of the root operand and an Allocate plus Write for every
/// result, all on TransformMappingResource. Without a root, all effects of the
/// nested ops are forwarded as-is. With a root, the nested effects on the entry
/// block argument are re-attributed to the root operand.
/// (Listing omits lines 696, 715 and 731.)
695 void transform::SequenceOp::getEffects(
697  auto *mappingResource = TransformMappingResource::get();
  // NOTE(review): when there is no root, getRoot() is a null Value and this
  // Read effect is recorded against it. Confirm that is acceptable.
698  effects.emplace_back(MemoryEffects::Read::get(), getRoot(), mappingResource);
699 
700  for (Value result : getResults()) {
701  effects.emplace_back(MemoryEffects::Allocate::get(), result,
702  mappingResource);
703  effects.emplace_back(MemoryEffects::Write::get(), result, mappingResource);
704  }
705 
706  if (!getRoot()) {
707  for (Operation &op : *getBodyBlock()) {
708  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
709  if (!iface) {
710  // TODO: fill all possible effects; or require ops to actually implement
711  // the memory effect interface always
712  assert(false);
713  }
714 
      // NOTE(review): with assertions disabled, a null `iface` reaches the call
      // below, unless the hidden line 715 guards it.
716  iface.getEffects(effects);
717  }
718  return;
719  }
720 
721  // Carry over all effects on the argument of the entry block as those on the
722  // operand, this is the same value just remapped.
723  for (Operation &op : *getBodyBlock()) {
724  auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
725  if (!iface) {
726  // TODO: fill all possible effects; or require ops to actually implement
727  // the memory effect interface always
728  assert(false);
729  }
730 
    // NOTE(review): same null-`iface` concern as above in builds without
    // assertions.
732  iface.getEffectsOnValue(getBodyBlock()->getArgument(0), nestedEffects);
733  for (const auto &effect : nestedEffects)
734  effects.emplace_back(effect.getEffect(), getRoot(), effect.getResource());
735  }
736 }
737 
/// Entering the body forwards the op's operand when it has exactly one, and an
/// empty range otherwise. (Listing omits line 738, the return type.)
739 transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) {
740  assert(index && *index == 0 && "unexpected region index");
741  if (getOperation()->getNumOperands() == 1)
742  return getOperation()->getOperands();
743  return OperandRange(getOperation()->operand_end(),
744  getOperation()->operand_end());
745 }
746 
/// From the parent op, control enters the body. The body's arguments are used
/// as successor inputs when `operands` is non-empty; the other branch, line
/// 754, is not visible. From the body, control exits to the op's results.
/// (Listing omits lines 749 and 754.)
747 void transform::SequenceOp::getSuccessorRegions(
748  Optional<unsigned> index, ArrayRef<Attribute> operands,
750  if (!index) {
751  Region *bodyRegion = &getBody();
752  regions.emplace_back(bodyRegion, !operands.empty()
753  ? bodyRegion->getArguments()
755  return;
756  }
757 
758  assert(*index == 0 && "unexpected region index");
759  regions.emplace_back(getOperation()->getResults());
760 }
761 
/// The single body region is executed exactly once.
/// (Listing omits line 763, the parameters.)
762 void transform::SequenceOp::getRegionInvocationBounds(
764  (void)operands;
765  bounds.emplace_back(1, 1);
766 }
767 
768 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
769  TypeRange resultTypes,
770  FailurePropagationMode failurePropagationMode,
771  Value root,
772  SequenceBodyBuilderFn bodyBuilder) {
773  build(builder, state, resultTypes, failurePropagationMode, root);
774  Region *region = state.regions.back().get();
775  auto bbArgType = root.getType();
776  Block *bodyBlock = builder.createBlock(
777  region, region->begin(), TypeRange{bbArgType}, {state.location});
778 
779  // Populate body.
780  OpBuilder::InsertionGuard guard(builder);
781  builder.setInsertionPointToStart(bodyBlock);
782  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
783 }
784 
785 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
786  TypeRange resultTypes,
787  FailurePropagationMode failurePropagationMode,
788  Type bbArgType,
789  SequenceBodyBuilderFn bodyBuilder) {
790  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
791  Region *region = state.regions.back().get();
792  Block *bodyBlock = builder.createBlock(
793  region, region->begin(), TypeRange{bbArgType}, {state.location});
794 
795  // Populate body.
796  OpBuilder::InsertionGuard guard(builder);
797  builder.setInsertionPointToStart(bodyBlock);
798  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
799 }
800 
801 //===----------------------------------------------------------------------===//
802 // WithPDLPatternsOp
803 //===----------------------------------------------------------------------===//
804 
/// Finds the first non-pattern op in the body and treats it as the transform to
/// run. It attaches a PatternApplicatorExtension over this op's patterns for
/// the duration of the call (removed on scope exit), maps the block arguments,
/// and applies that transform. (Listing omits lines 805 and 824.)
806 transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
807  transform::TransformState &state) {
  // NOTE(review): `pdlModuleOp` is created here but never used in the visible
  // code. It looks like dead code.
808  OwningOpRef<ModuleOp> pdlModuleOp =
809  ModuleOp::create(getOperation()->getLoc());
810  TransformOpInterface transformOp = nullptr;
811  for (Operation &nested : getBody().front()) {
812  if (!isa<pdl::PatternOp>(nested)) {
813  transformOp = cast<TransformOpInterface>(nested);
814  break;
815  }
816  }
817 
818  state.addExtension<PatternApplicatorExtension>(getOperation());
819  auto guard = llvm::make_scope_exit(
820  [&]() { state.removeExtension<PatternApplicatorExtension>(); });
821 
822  auto scope = state.make_region_scope(getBody());
823  if (failed(mapBlockArguments(state)))
  // NOTE(review): if the body contains only pdl.pattern ops, `transformOp` is
  // still null here. The visible part of WithPDLPatternsOp::verify does not
  // appear to rule that out; confirm.
825  return state.applyTransform(transformOp);
826 }
827 
/// WithPDLPatternsOp::verify. Judging by its diagnostics, the body may contain
/// only pdl.pattern ops plus a single top-level transform op, and the op may
/// not be nested inside another WithPDLPatternsOp.
/// (Listing omits lines 828, 835, 837 and 847; the condition on line 835 is
/// not visible.)
829  Block *body = getBodyBlock();
830  Operation *topLevelOp = nullptr;
831  for (Operation &op : body->getOperations()) {
832  if (isa<pdl::PatternOp>(op))
833  continue;
834 
836  if (topLevelOp) {
838  emitOpError() << "expects only one non-pattern op in its body";
839  diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
840  diag.attachNote(op.getLoc()) << "second non-pattern op";
841  return diag;
842  }
843  topLevelOp = &op;
844  continue;
845  }
846 
848  emitOpError()
849  << "expects only pattern and top-level transform ops in its body";
850  diag.attachNote(op.getLoc()) << "offending op";
851  return diag;
852  }
853 
854  if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
855  InFlightDiagnostic diag = emitOpError() << "cannot be nested";
856  diag.attachNote(parent.getLoc()) << "parent operation";
857  return diag;
858  }
859 
860  return success();
861 }
862 
863 //===----------------------------------------------------------------------===//
864 // PrintOp
865 //===----------------------------------------------------------------------===//
866 
867 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
868  StringRef name) {
869  if (!name.empty()) {
870  result.addAttribute(PrintOp::getNameAttrName(result.name),
871  builder.getStrArrayAttr(name));
872  }
873 }
874 
875 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
876  Value target, StringRef name) {
877  result.addOperands({target});
878  build(builder, result, name);
879 }
880 
/// Writes to stderr a "[[[ IR printer: " banner, followed by the optional name.
/// Without a target, it then prints the whole top-level payload op; otherwise
/// it prints each payload op of the target handle on its own line.
/// (Listing omits lines 881, 890 and 898.)
882 transform::PrintOp::apply(transform::TransformResults &results,
883  transform::TransformState &state) {
884  llvm::errs() << "[[[ IR printer: ";
885  if (getName().has_value())
886  llvm::errs() << *getName() << " ";
887 
888  if (!getTarget()) {
889  llvm::errs() << "top-level ]]]\n" << *state.getTopLevel() << "\n";
891  }
892 
893  llvm::errs() << "]]]\n";
894  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
895  for (Operation *target : targets)
896  llvm::errs() << *target << "\n";
897 
899 }
900 
/// Reads the target handle and the payload IR, and records a Write on the
/// default resource to model printing. (Listing omits line 902.)
901 void transform::PrintOp::getEffects(
903  onlyReadsHandle(getTarget(), effects);
904  onlyReadsPayload(effects);
905 
906  // There is no resource for stderr file descriptor, so just declare print
907  // writes into the default resource.
908  effects.emplace_back(MemoryEffects::Write::get());
909 }
static std::string diag(llvm::Value &value)
static constexpr const bool value
static void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
#define DBGS()
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
This class represents an argument of a Block.
Definition: Value.h:296
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
OpListType & getOperations()
Definition: Block.h:126
Operation & front()
Definition: Block.h:142
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:113
MLIRContext * getContext() const
Definition: Builders.h:54
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition: Builders.cpp:288
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:589
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:307
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:395
This class represents a single result from folding an operation.
Definition: OpDefinition.h:233
This class represents an operand of an operation.
Definition: Value.h:247
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
This is a value defined by a result of an operation.
Definition: Value.h:442
This class provides the API for ops that are known to be isolated from above.
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:41
type_range getTypes() const
Definition: ValueRange.cpp:26
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Value getOperand(unsigned idx)
Definition: Operation.h:267
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
Operation * clone(BlockAndValueMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:558
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
unsigned getNumOperands()
Definition: Operation.h:263
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:144
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:179
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
result_range getOpResults()
Definition: Operation.h:337
result_range getResults()
Definition: Operation.h:332
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:574
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:418
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specifie...
void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn)
Register a rewrite function with PDL.
void registerConstraintFunction(StringRef name, PDLConstraintFunction constraintFn)
Register a constraint function with PDL.
This class manages the application of a group of rewrite patterns, with a user-provided cost model.
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter, function_ref< bool(const Pattern &)> canApply={}, function_ref< void(const Pattern &)> onFailure={}, function_ref< LogicalResult(const Pattern &)> onSuccess={})
Attempt to match and rewrite the given op with any pattern, allowing a predicate to decide if a patte...
void applyDefaultCostModel()
Apply the default cost model that solely uses the pattern's static benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:605
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
iterator begin()
Definition: Region.h:55
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:23
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:40
This trait is supposed to be attached to Transform dialect operations that can be standalone top-leve...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void set(OpResult value, ArrayRef< Operation * > ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
Base class for TransformState extensions that allow TransformState to contain user-specified informat...
The state maintained across applications of various ops implementing the TransformOpInterface.
LogicalResult mapBlockArguments(BlockArgument argument, ArrayRef< Operation * > operations)
Records the mapping between a block argument in the transform IR and a list of operations in the payl...
DiagnosedSilenceableFailure applyTransform(TransformOpInterface transform)
Applies the transformation specified by the given transform op and updates the state accordingly.
RegionScope make_region_scope(Region &region)
Creates a new region scope for the given region.
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
Ty * getExtension()
Returns the extension of the specified type.
Operation * getTopLevel() const
Returns the op at which the transformation state is rooted.
void removeExtension()
Removes the extension of the specified type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:27
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:372
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.