MLIR  19.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/Verifier.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Pass/PassManager.h"
31 #include "mlir/Pass/PassRegistry.h"
32 #include "mlir/Transforms/CSE.h"
36 #include "llvm/ADT/DenseSet.h"
37 #include "llvm/ADT/STLExtras.h"
38 #include "llvm/ADT/ScopeExit.h"
39 #include "llvm/ADT/SmallPtrSet.h"
40 #include "llvm/ADT/TypeSwitch.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/ErrorHandling.h"
43 #include <optional>
44 
// Debug output channel for this file; LLVM_DEBUG output is shown with
// -debug-only=transform-dialect, and DBGS() prefixes each line with the type.
#define DEBUG_TYPE "transform-dialect"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

// Separate debug channel used by the matcher logic (e.g. CollectMatchingOp),
// shown with -debug-only=transform-matcher.
#define DEBUG_TYPE_MATCHER "transform-matcher"
#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)

using namespace mlir;
53 
55  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
56  Type &rootType,
57  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
58  SmallVectorImpl<Type> &extraBindingTypes);
59 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
60  Value root, Type rootType,
61  ValueRange extraBindings,
62  TypeRange extraBindingTypes);
63 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
64  ArrayAttr matchers, ArrayAttr actions);
66  ArrayAttr &matchers,
67  ArrayAttr &actions);
68 
69 /// Helper function to check if the given transform op is contained in (or
70 /// equal to) the given payload target op. In that case, an error is returned.
71 /// Transforming transform IR that is currently executing is generally unsafe.
73 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
74  Operation *payload) {
75  Operation *transformAncestor = transform.getOperation();
76  while (transformAncestor) {
77  if (transformAncestor == payload) {
79  transform.emitDefiniteFailure()
80  << "cannot apply transform to itself (or one of its ancestors)";
81  diag.attachNote(payload->getLoc()) << "target payload op";
82  return diag;
83  }
84  transformAncestor = transformAncestor->getParentOp();
85  }
87 }
88 
89 #define GET_OP_CLASSES
90 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
91 
92 //===----------------------------------------------------------------------===//
93 // AlternativesOp
94 //===----------------------------------------------------------------------===//
95 
97 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
98  if (!point.isParent() && getOperation()->getNumOperands() == 1)
99  return getOperation()->getOperands();
100  return OperandRange(getOperation()->operand_end(),
101  getOperation()->operand_end());
102 }
103 
104 void transform::AlternativesOp::getSuccessorRegions(
105  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
106  for (Region &alternative : llvm::drop_begin(
107  getAlternatives(),
108  point.isParent() ? 0
109  : point.getRegionOrNull()->getRegionNumber() + 1)) {
110  regions.emplace_back(&alternative, !getOperands().empty()
111  ? alternative.getArguments()
113  }
114  if (!point.isParent())
115  regions.emplace_back(getOperation()->getResults());
116 }
117 
118 void transform::AlternativesOp::getRegionInvocationBounds(
119  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
120  (void)operands;
121  // The region corresponding to the first alternative is always executed, the
122  // remaining may or may not be executed.
123  bounds.reserve(getNumRegions());
124  bounds.emplace_back(1, 1);
125  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
126 }
127 
129  transform::TransformResults &results) {
130  for (const auto &res : block->getParentOp()->getOpResults())
131  results.set(res, {});
132 }
133 
135 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
137  transform::TransformState &state) {
138  SmallVector<Operation *> originals;
139  if (Value scopeHandle = getScope())
140  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
141  else
142  originals.push_back(state.getTopLevel());
143 
144  for (Operation *original : originals) {
145  if (original->isAncestor(getOperation())) {
146  auto diag = emitDefiniteFailure()
147  << "scope must not contain the transforms being applied";
148  diag.attachNote(original->getLoc()) << "scope";
149  return diag;
150  }
151  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
152  auto diag = emitDefiniteFailure()
153  << "only isolated-from-above ops can be alternative scopes";
154  diag.attachNote(original->getLoc()) << "scope";
155  return diag;
156  }
157  }
158 
159  for (Region &reg : getAlternatives()) {
160  // Clone the scope operations and make the transforms in this alternative
161  // region apply to them by virtue of mapping the block argument (the only
162  // visible handle) to the cloned scope operations. This effectively prevents
163  // the transformation from accessing any IR outside the scope.
164  auto scope = state.make_region_scope(reg);
165  auto clones = llvm::to_vector(
166  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
167  auto deleteClones = llvm::make_scope_exit([&] {
168  for (Operation *clone : clones)
169  clone->erase();
170  });
171  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
173 
174  bool failed = false;
175  for (Operation &transform : reg.front().without_terminator()) {
177  state.applyTransform(cast<TransformOpInterface>(transform));
178  if (result.isSilenceableFailure()) {
179  LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
180  << "\n");
181  failed = true;
182  break;
183  }
184 
185  if (::mlir::failed(result.silence()))
187  }
188 
189  // If all operations in the given alternative succeeded, no need to consider
190  // the rest. Replace the original scoping operation with the clone on which
191  // the transformations were performed.
192  if (!failed) {
193  // We will be using the clones, so cancel their scheduled deletion.
194  deleteClones.release();
195  TrackingListener listener(state, *this);
196  IRRewriter rewriter(getContext(), &listener);
197  for (const auto &kvp : llvm::zip(originals, clones)) {
198  Operation *original = std::get<0>(kvp);
199  Operation *clone = std::get<1>(kvp);
200  original->getBlock()->getOperations().insert(original->getIterator(),
201  clone);
202  rewriter.replaceOp(original, clone->getResults());
203  }
204  detail::forwardTerminatorOperands(&reg.front(), state, results);
206  }
207  }
208  return emitSilenceableError() << "all alternatives failed";
209 }
210 
211 void transform::AlternativesOp::getEffects(
212  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
213  consumesHandle(getOperands(), effects);
214  producesHandle(getResults(), effects);
215  for (Region *region : getRegions()) {
216  if (!region->empty())
217  producesHandle(region->front().getArguments(), effects);
218  }
219  modifiesPayload(effects);
220 }
221 
223  for (Region &alternative : getAlternatives()) {
224  Block &block = alternative.front();
225  Operation *terminator = block.getTerminator();
226  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
227  InFlightDiagnostic diag = emitOpError()
228  << "expects terminator operands to have the "
229  "same type as results of the operation";
230  diag.attachNote(terminator->getLoc()) << "terminator";
231  return diag;
232  }
233  }
234 
235  return success();
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // AnnotateOp
240 //===----------------------------------------------------------------------===//
241 
243 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
245  transform::TransformState &state) {
246  SmallVector<Operation *> targets =
247  llvm::to_vector(state.getPayloadOps(getTarget()));
248 
250  if (auto paramH = getParam()) {
251  ArrayRef<Attribute> params = state.getParams(paramH);
252  if (params.size() != 1) {
253  if (targets.size() != params.size()) {
254  return emitSilenceableError()
255  << "parameter and target have different payload lengths ("
256  << params.size() << " vs " << targets.size() << ")";
257  }
258  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
259  target->setAttr(getName(), attr);
261  }
262  attr = params[0];
263  }
264  for (auto *target : targets)
265  target->setAttr(getName(), attr);
267 }
268 
269 void transform::AnnotateOp::getEffects(
270  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
271  onlyReadsHandle(getTarget(), effects);
272  onlyReadsHandle(getParam(), effects);
273  modifiesPayload(effects);
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // ApplyCommonSubexpressionEliminationOp
278 //===----------------------------------------------------------------------===//
279 
281 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
282  transform::TransformRewriter &rewriter, Operation *target,
283  ApplyToEachResultList &results, transform::TransformState &state) {
284  // Make sure that this transform is not applied to itself. Modifying the
285  // transform IR while it is being interpreted is generally dangerous.
286  DiagnosedSilenceableFailure payloadCheck =
288  if (!payloadCheck.succeeded())
289  return payloadCheck;
290 
291  DominanceInfo domInfo;
292  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
294 }
295 
296 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
297  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
298  transform::onlyReadsHandle(getTarget(), effects);
300 }
301 
302 //===----------------------------------------------------------------------===//
303 // ApplyDeadCodeEliminationOp
304 //===----------------------------------------------------------------------===//
305 
306 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
307  transform::TransformRewriter &rewriter, Operation *target,
308  ApplyToEachResultList &results, transform::TransformState &state) {
309  // Make sure that this transform is not applied to itself. Modifying the
310  // transform IR while it is being interpreted is generally dangerous.
311  DiagnosedSilenceableFailure payloadCheck =
313  if (!payloadCheck.succeeded())
314  return payloadCheck;
315 
316  // Maintain a worklist of potentially dead ops.
317  SetVector<Operation *> worklist;
318 
319  // Helper function that adds all defining ops of used values (operands and
320  // operands of nested ops).
321  auto addDefiningOpsToWorklist = [&](Operation *op) {
322  op->walk([&](Operation *op) {
323  for (Value v : op->getOperands())
324  if (Operation *defOp = v.getDefiningOp())
325  if (target->isProperAncestor(defOp))
326  worklist.insert(defOp);
327  });
328  };
329 
330  // Helper function that erases an op.
331  auto eraseOp = [&](Operation *op) {
332  // Remove op and nested ops from the worklist.
333  op->walk([&](Operation *op) {
334  const auto *it = llvm::find(worklist, op);
335  if (it != worklist.end())
336  worklist.erase(it);
337  });
338  rewriter.eraseOp(op);
339  };
340 
341  // Initial walk over the IR.
342  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
343  if (op != target && isOpTriviallyDead(op)) {
344  addDefiningOpsToWorklist(op);
345  eraseOp(op);
346  }
347  });
348 
349  // Erase all ops that have become dead.
350  while (!worklist.empty()) {
351  Operation *op = worklist.pop_back_val();
352  if (!isOpTriviallyDead(op))
353  continue;
354  addDefiningOpsToWorklist(op);
355  eraseOp(op);
356  }
357 
359 }
360 
361 void transform::ApplyDeadCodeEliminationOp::getEffects(
362  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
363  transform::onlyReadsHandle(getTarget(), effects);
365 }
366 
367 //===----------------------------------------------------------------------===//
368 // ApplyPatternsOp
369 //===----------------------------------------------------------------------===//
370 
371 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
372  transform::TransformRewriter &rewriter, Operation *target,
373  ApplyToEachResultList &results, transform::TransformState &state) {
374  // Make sure that this transform is not applied to itself. Modifying the
375  // transform IR while it is being interpreted is generally dangerous. Even
376  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
377  // performs many additional simplifications such as dead code elimination.
378  DiagnosedSilenceableFailure payloadCheck =
380  if (!payloadCheck.succeeded())
381  return payloadCheck;
382 
383  // Gather all specified patterns.
384  MLIRContext *ctx = target->getContext();
385  RewritePatternSet patterns(ctx);
386  if (!getRegion().empty()) {
387  for (Operation &op : getRegion().front()) {
388  cast<transform::PatternDescriptorOpInterface>(&op)
389  .populatePatternsWithState(patterns, state);
390  }
391  }
392 
393  // Configure the GreedyPatternRewriteDriver.
394  GreedyRewriteConfig config;
395  config.listener =
396  static_cast<RewriterBase::Listener *>(rewriter.getListener());
397  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
398 
399  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
400  // was requested, apply the greedy pattern rewrite only once. (The greedy
401  // pattern rewrite driver already iterates to a fixpoint internally.)
402  bool cseChanged = false;
403  // One or two iterations should be sufficient. Stop iterating after a certain
404  // threshold to make debugging easier.
405  static const int64_t kNumMaxIterations = 50;
406  int64_t iteration = 0;
407  do {
408  LogicalResult result = failure();
409  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
410  // Op is isolated from above. Apply patterns and also perform region
411  // simplification.
412  result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
413  } else {
414  // Manually gather list of ops because the other
415  // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
416  // from above. This way, patterns can be applied to ops that are not
417  // isolated from above. Regions are not being simplified. Furthermore,
418  // only a single greedy rewrite iteration is performed.
420  target->walk([&](Operation *nestedOp) {
421  if (target != nestedOp)
422  ops.push_back(nestedOp);
423  });
424  result = applyOpPatternsAndFold(ops, frozenPatterns, config);
425  }
426 
427  // A failure typically indicates that the pattern application did not
428  // converge.
429  if (failed(result)) {
430  return emitSilenceableFailure(target)
431  << "greedy pattern application failed";
432  }
433 
434  if (getApplyCse()) {
435  DominanceInfo domInfo;
436  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
437  &cseChanged);
438  }
439  } while (cseChanged && ++iteration < kNumMaxIterations);
440 
441  if (iteration == kNumMaxIterations)
442  return emitDefiniteFailure() << "fixpoint iteration did not converge";
443 
445 }
446 
448  if (!getRegion().empty()) {
449  for (Operation &op : getRegion().front()) {
450  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
451  InFlightDiagnostic diag = emitOpError()
452  << "expected children ops to implement "
453  "PatternDescriptorOpInterface";
454  diag.attachNote(op.getLoc()) << "op without interface";
455  return diag;
456  }
457  }
458  }
459  return success();
460 }
461 
462 void transform::ApplyPatternsOp::getEffects(
463  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
464  transform::onlyReadsHandle(getTarget(), effects);
466 }
467 
468 void transform::ApplyPatternsOp::build(
469  OpBuilder &builder, OperationState &result, Value target,
470  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
471  result.addOperands(target);
472 
473  OpBuilder::InsertionGuard g(builder);
474  Region *region = result.addRegion();
475  builder.createBlock(region);
476  if (bodyBuilder)
477  bodyBuilder(builder, result.location);
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // ApplyCanonicalizationPatternsOp
482 //===----------------------------------------------------------------------===//
483 
484 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
485  RewritePatternSet &patterns) {
486  MLIRContext *ctx = patterns.getContext();
487  for (Dialect *dialect : ctx->getLoadedDialects())
488  dialect->getCanonicalizationPatterns(patterns);
490  op.getCanonicalizationPatterns(patterns, ctx);
491 }
492 
493 //===----------------------------------------------------------------------===//
494 // ApplyConversionPatternsOp
495 //===----------------------------------------------------------------------===//
496 
497 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
500  MLIRContext *ctx = getContext();
501 
502  // Instantiate the default type converter if a type converter builder is
503  // specified.
504  std::unique_ptr<TypeConverter> defaultTypeConverter;
505  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
506  getDefaultTypeConverter();
507  if (typeConverterBuilder)
508  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
509 
510  // Configure conversion target.
511  ConversionTarget conversionTarget(*getContext());
512  if (getLegalOps())
513  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
514  conversionTarget.addLegalOp(
515  OperationName(cast<StringAttr>(attr).getValue(), ctx));
516  if (getIllegalOps())
517  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
518  conversionTarget.addIllegalOp(
519  OperationName(cast<StringAttr>(attr).getValue(), ctx));
520  if (getLegalDialects())
521  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
522  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
523  if (getIllegalDialects())
524  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
525  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
526 
527  // Gather all specified patterns.
528  RewritePatternSet patterns(ctx);
529  // Need to keep the converters alive until after pattern application because
530  // the patterns take a reference to an object that would otherwise get out of
531  // scope.
532  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
533  if (!getPatterns().empty()) {
534  for (Operation &op : getPatterns().front()) {
535  auto descriptor =
536  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
537 
538  // Check if this pattern set specifies a type converter.
539  std::unique_ptr<TypeConverter> typeConverter =
540  descriptor.getTypeConverter();
541  TypeConverter *converter = nullptr;
542  if (typeConverter) {
543  keepAliveConverters.emplace_back(std::move(typeConverter));
544  converter = keepAliveConverters.back().get();
545  } else {
546  // No type converter specified: Use the default type converter.
547  if (!defaultTypeConverter) {
548  auto diag = emitDefiniteFailure()
549  << "pattern descriptor does not specify type "
550  "converter and apply_conversion_patterns op has "
551  "no default type converter";
552  diag.attachNote(op.getLoc()) << "pattern descriptor op";
553  return diag;
554  }
555  converter = defaultTypeConverter.get();
556  }
557 
558  // Add descriptor-specific updates to the conversion target, which may
559  // depend on the final type converter. In structural converters, the
560  // legality of types dictates the dynamic legality of an operation.
561  descriptor.populateConversionTargetRules(*converter, conversionTarget);
562 
563  descriptor.populatePatterns(*converter, patterns);
564  }
565  }
566 
567  // Attach a tracking listener if handles should be preserved. We configure the
568  // listener to allow op replacements with different names, as conversion
569  // patterns typically replace ops with replacement ops that have a different
570  // name.
571  TrackingListenerConfig trackingConfig;
572  trackingConfig.requireMatchingReplacementOpName = false;
573  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
574  ConversionConfig conversionConfig;
575  if (getPreserveHandles())
576  conversionConfig.listener = &trackingListener;
577 
578  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
579  for (Operation *target : state.getPayloadOps(getTarget())) {
580  // Make sure that this transform is not applied to itself. Modifying the
581  // transform IR while it is being interpreted is generally dangerous.
582  DiagnosedSilenceableFailure payloadCheck =
584  if (!payloadCheck.succeeded())
585  return payloadCheck;
586 
587  LogicalResult status = failure();
588  if (getPartialConversion()) {
589  status = applyPartialConversion(target, conversionTarget, frozenPatterns,
590  conversionConfig);
591  } else {
592  status = applyFullConversion(target, conversionTarget, frozenPatterns,
593  conversionConfig);
594  }
595 
596  // Check dialect conversion state.
598  if (failed(status)) {
599  diag = emitSilenceableError() << "dialect conversion failed";
600  diag.attachNote(target->getLoc()) << "target op";
601  }
602 
603  // Check tracking listener error state.
604  DiagnosedSilenceableFailure trackingFailure =
605  trackingListener.checkAndResetError();
606  if (!trackingFailure.succeeded()) {
607  if (diag.succeeded()) {
608  // Tracking failure is the only failure.
609  return trackingFailure;
610  } else {
611  diag.attachNote() << "tracking listener also failed: "
612  << trackingFailure.getMessage();
613  (void)trackingFailure.silence();
614  }
615  }
616 
617  if (!diag.succeeded())
618  return diag;
619  }
620 
622 }
623 
625  if (getNumRegions() != 1 && getNumRegions() != 2)
626  return emitOpError() << "expected 1 or 2 regions";
627  if (!getPatterns().empty()) {
628  for (Operation &op : getPatterns().front()) {
629  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
631  emitOpError() << "expected pattern children ops to implement "
632  "ConversionPatternDescriptorOpInterface";
633  diag.attachNote(op.getLoc()) << "op without interface";
634  return diag;
635  }
636  }
637  }
638  if (getNumRegions() == 2) {
639  Region &typeConverterRegion = getRegion(1);
640  if (!llvm::hasSingleElement(typeConverterRegion.front()))
641  return emitOpError()
642  << "expected exactly one op in default type converter region";
643  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
644  &typeConverterRegion.front().front());
645  if (!typeConverterOp) {
646  InFlightDiagnostic diag = emitOpError()
647  << "expected default converter child op to "
648  "implement TypeConverterBuilderOpInterface";
649  diag.attachNote(typeConverterOp->getLoc()) << "op without interface";
650  return diag;
651  }
652  // Check default type converter type.
653  if (!getPatterns().empty()) {
654  for (Operation &op : getPatterns().front()) {
655  auto descriptor =
656  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
657  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
658  return failure();
659  }
660  }
661  }
662  return success();
663 }
664 
665 void transform::ApplyConversionPatternsOp::getEffects(
666  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
667  if (!getPreserveHandles()) {
668  transform::consumesHandle(getTarget(), effects);
669  } else {
670  transform::onlyReadsHandle(getTarget(), effects);
671  }
673 }
674 
675 void transform::ApplyConversionPatternsOp::build(
676  OpBuilder &builder, OperationState &result, Value target,
677  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
678  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
679  result.addOperands(target);
680 
681  {
682  OpBuilder::InsertionGuard g(builder);
683  Region *region1 = result.addRegion();
684  builder.createBlock(region1);
685  if (patternsBodyBuilder)
686  patternsBodyBuilder(builder, result.location);
687  }
688  {
689  OpBuilder::InsertionGuard g(builder);
690  Region *region2 = result.addRegion();
691  builder.createBlock(region2);
692  if (typeConverterBodyBuilder)
693  typeConverterBodyBuilder(builder, result.location);
694  }
695 }
696 
697 //===----------------------------------------------------------------------===//
698 // ApplyToLLVMConversionPatternsOp
699 //===----------------------------------------------------------------------===//
700 
701 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
702  TypeConverter &typeConverter, RewritePatternSet &patterns) {
703  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
704  assert(dialect && "expected that dialect is loaded");
705  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
706  // ConversionTarget is currently ignored because the enclosing
707  // apply_conversion_patterns op sets up its own ConversionTarget.
708  ConversionTarget target(*getContext());
709  iface->populateConvertToLLVMConversionPatterns(
710  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
711 }
712 
713 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
714  transform::TypeConverterBuilderOpInterface builder) {
715  if (builder.getTypeConverterType() != "LLVMTypeConverter")
716  return emitOpError("expected LLVMTypeConverter");
717  return success();
718 }
719 
721  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
722  if (!dialect)
723  return emitOpError("unknown dialect or dialect not loaded: ")
724  << getDialectName();
725  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
726  if (!iface)
727  return emitOpError(
728  "dialect does not implement ConvertToLLVMPatternInterface or "
729  "extension was not loaded: ")
730  << getDialectName();
731  return success();
732 }
733 
734 //===----------------------------------------------------------------------===//
735 // ApplyLoopInvariantCodeMotionOp
736 //===----------------------------------------------------------------------===//
737 
739 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
740  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
742  transform::TransformState &state) {
743  // Currently, LICM does not remove operations, so we don't need tracking.
744  // If this ever changes, add a LICM entry point that takes a rewriter.
745  moveLoopInvariantCode(target);
747 }
748 
749 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
750  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
751  transform::onlyReadsHandle(getTarget(), effects);
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // ApplyRegisteredPassOp
757 //===----------------------------------------------------------------------===//
758 
759 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
760  transform::TransformRewriter &rewriter, Operation *target,
761  ApplyToEachResultList &results, transform::TransformState &state) {
762  // Make sure that this transform is not applied to itself. Modifying the
763  // transform IR while it is being interpreted is generally dangerous. Even
764  // more so when applying passes because they may perform a wide range of IR
765  // modifications.
766  DiagnosedSilenceableFailure payloadCheck =
768  if (!payloadCheck.succeeded())
769  return payloadCheck;
770 
771  // Get pass or pass pipeline from registry.
772  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
773  if (!info)
774  info = PassInfo::lookup(getPassName());
775  if (!info)
776  return emitDefiniteFailure()
777  << "unknown pass or pass pipeline: " << getPassName();
778 
779  // Create pass manager and run the pass or pass pipeline.
780  PassManager pm(getContext());
781  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
782  emitError(msg);
783  return failure();
784  }))) {
785  return emitDefiniteFailure()
786  << "failed to add pass or pass pipeline to pipeline: "
787  << getPassName();
788  }
789  if (failed(pm.run(target))) {
790  auto diag = emitSilenceableError() << "pass pipeline failed";
791  diag.attachNote(target->getLoc()) << "target op";
792  return diag;
793  }
794 
795  results.push_back(target);
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // CastOp
801 //===----------------------------------------------------------------------===//
802 
804 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
805  Operation *target, ApplyToEachResultList &results,
806  transform::TransformState &state) {
807  results.push_back(target);
809 }
810 
811 void transform::CastOp::getEffects(
812  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
813  onlyReadsPayload(effects);
814  onlyReadsHandle(getInput(), effects);
815  producesHandle(getOutput(), effects);
816 }
817 
819  assert(inputs.size() == 1 && "expected one input");
820  assert(outputs.size() == 1 && "expected one output");
821  return llvm::all_of(
822  std::initializer_list<Type>{inputs.front(), outputs.front()},
823  llvm::IsaPred<transform::TransformHandleTypeInterface>);
824 }
825 
826 //===----------------------------------------------------------------------===//
827 // CollectMatchingOp
828 //===----------------------------------------------------------------------===//
829 
830 /// Applies matcher operations from the given `block` assigning `op` as the
831 /// payload of the block's first argument. Updates `state` accordingly. If any
832 /// of the matcher produces a silenceable failure, discards it (printing the
833 /// content to the debug output stream) and returns failure. If any of the
834 /// matchers produces a definite failure, reports it and returns failure. If all
835 /// matchers in the block succeed, populates `mappings` with the payload
836 /// entities associated with the block terminator operands.
839  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
840  assert(block.getParent() && "cannot match using a detached block");
841  auto matchScope = state.make_region_scope(*block.getParent());
842  if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
844 
845  for (Operation &match : block.without_terminator()) {
846  if (!isa<transform::MatchOpInterface>(match)) {
847  return emitDefiniteFailure(match.getLoc())
848  << "expected operations in the match part to "
849  "implement MatchOpInterface";
850  }
852  state.applyTransform(cast<transform::TransformOpInterface>(match));
853  if (diag.succeeded())
854  continue;
855 
856  return diag;
857  }
858 
859  // Remember the values mapped to the terminator operands so we can
860  // forward them to the action.
861  ValueRange yieldedValues = block.getTerminator()->getOperands();
862  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
864 }
865 
866 /// Returns `true` if both types implement one of the interfaces provided as
867 /// template parameters.
868 template <typename... Tys>
869 static bool implementSameInterface(Type t1, Type t2) {
870  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
871 }
872 
873 /// Returns `true` if both types implement one of the transform dialect
874 /// interfaces.
876  return implementSameInterface<transform::TransformHandleTypeInterface,
877  transform::TransformParamTypeInterface,
878  transform::TransformValueHandleTypeInterface>(
879  t1, t2);
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // CollectMatchingOp
884 //===----------------------------------------------------------------------===//
885 
887 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
889  transform::TransformState &state) {
890  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
891  getOperation(), getMatcher());
892  if (matcher.isExternal()) {
893  return emitDefiniteFailure()
894  << "unresolved external symbol " << getMatcher();
895  }
896 
898  rawResults.resize(getOperation()->getNumResults());
899  std::optional<DiagnosedSilenceableFailure> maybeFailure;
900  for (Operation *root : state.getPayloadOps(getRoot())) {
901  WalkResult walkResult = root->walk([&](Operation *op) {
902  DEBUG_MATCHER({
903  DBGS_MATCHER() << "matching ";
904  op->print(llvm::dbgs(),
905  OpPrintingFlags().assumeVerified().skipRegions());
906  llvm::dbgs() << " @" << op << "\n";
907  });
908 
909  // Try matching.
912  matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
913  if (diag.isDefiniteFailure())
914  return WalkResult::interrupt();
915  if (diag.isSilenceableFailure()) {
916  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
917  << " failed: " << diag.getMessage());
918  return WalkResult::advance();
919  }
920 
921  // If succeeded, collect results.
922  for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
923  if (mapping.size() != 1) {
924  maybeFailure.emplace(emitSilenceableError()
925  << "result #" << i << ", associated with "
926  << mapping.size()
927  << " payload objects, expected 1");
928  return WalkResult::interrupt();
929  }
930  rawResults[i].push_back(mapping[0]);
931  }
932  return WalkResult::advance();
933  });
934  if (walkResult.wasInterrupted())
935  return std::move(*maybeFailure);
936  assert(!maybeFailure && "failure set but the walk was not interrupted");
937 
938  for (auto &&[opResult, rawResult] :
939  llvm::zip_equal(getOperation()->getResults(), rawResults)) {
940  results.setMappedValues(opResult, rawResult);
941  }
942  }
944 }
945 
946 void transform::CollectMatchingOp::getEffects(
947  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
948  onlyReadsHandle(getRoot(), effects);
949  producesHandle(getResults(), effects);
950  onlyReadsPayload(effects);
951 }
952 
953 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
954  SymbolTableCollection &symbolTable) {
955  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
956  symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
957  if (!matcherSymbol ||
958  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
959  return emitError() << "unresolved matcher symbol " << getMatcher();
960 
961  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
962  if (argumentTypes.size() != 1 ||
963  !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
964  return emitError()
965  << "expected the matcher to take one operation handle argument";
966  }
967  if (!matcherSymbol.getArgAttr(
968  0, transform::TransformDialect::kArgReadOnlyAttrName)) {
969  return emitError() << "expected the matcher argument to be marked readonly";
970  }
971 
972  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
973  if (resultTypes.size() != getOperation()->getNumResults()) {
974  return emitError()
975  << "expected the matcher to yield as many values as op has results ("
976  << getOperation()->getNumResults() << "), got "
977  << resultTypes.size();
978  }
979 
980  for (auto &&[i, matcherType, resultType] :
981  llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
982  if (implementSameTransformInterface(matcherType, resultType))
983  continue;
984 
985  return emitError()
986  << "mismatching type interfaces for matcher result and op result #"
987  << i;
988  }
989 
990  return success();
991 }
992 
993 //===----------------------------------------------------------------------===//
994 // ForeachMatchOp
995 //===----------------------------------------------------------------------===//
996 
998 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1000  transform::TransformState &state) {
1002  matchActionPairs;
1003  matchActionPairs.reserve(getMatchers().size());
1004  SymbolTableCollection symbolTable;
1005  for (auto &&[matcher, action] :
1006  llvm::zip_equal(getMatchers(), getActions())) {
1007  auto matcherSymbol =
1008  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1009  getOperation(), cast<SymbolRefAttr>(matcher));
1010  auto actionSymbol =
1011  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1012  getOperation(), cast<SymbolRefAttr>(action));
1013  assert(matcherSymbol && actionSymbol &&
1014  "unresolved symbols not caught by the verifier");
1015 
1016  if (matcherSymbol.isExternal())
1017  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1018  if (actionSymbol.isExternal())
1019  return emitDefiniteFailure() << "unresolved external symbol " << action;
1020 
1021  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1022  }
1023 
1024  DiagnosedSilenceableFailure overallDiag =
1026  for (Operation *root : state.getPayloadOps(getRoot())) {
1027  WalkResult walkResult = root->walk([&](Operation *op) {
1028  // If getRestrictRoot is not present, skip over the root op itself so we
1029  // don't invalidate it.
1030  if (!getRestrictRoot() && op == root)
1031  return WalkResult::advance();
1032 
1033  DEBUG_MATCHER({
1034  DBGS_MATCHER() << "matching ";
1035  op->print(llvm::dbgs(),
1036  OpPrintingFlags().assumeVerified().skipRegions());
1037  llvm::dbgs() << " @" << op << "\n";
1038  });
1039 
1040  // Try all the match/action pairs until the first successful match.
1041  for (auto [matcher, action] : matchActionPairs) {
1044  matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
1045  if (diag.isDefiniteFailure())
1046  return WalkResult::interrupt();
1047  if (diag.isSilenceableFailure()) {
1048  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1049  << " failed: " << diag.getMessage());
1050  continue;
1051  }
1052 
1053  auto scope = state.make_region_scope(action.getFunctionBody());
1054  for (auto &&[arg, map] : llvm::zip_equal(
1055  action.getFunctionBody().front().getArguments(), mappings)) {
1056  if (failed(state.mapBlockArgument(arg, map)))
1057  return WalkResult::interrupt();
1058  }
1059 
1060  for (Operation &transform :
1061  action.getFunctionBody().front().without_terminator()) {
1063  state.applyTransform(cast<TransformOpInterface>(transform));
1064  if (result.isDefiniteFailure())
1065  return WalkResult::interrupt();
1066  if (result.isSilenceableFailure()) {
1067  if (overallDiag.succeeded()) {
1068  overallDiag = emitSilenceableError() << "actions failed";
1069  }
1070  overallDiag.attachNote(action->getLoc())
1071  << "failed action: " << result.getMessage();
1072  overallDiag.attachNote(op->getLoc())
1073  << "when applied to this matching payload";
1074  (void)result.silence();
1075  continue;
1076  }
1077  }
1078  break;
1079  }
1080  return WalkResult::advance();
1081  });
1082  if (walkResult.wasInterrupted())
1084  }
1085 
1086  // The root operation should not have been affected, so we can just reassign
1087  // the payload to the result. Note that we need to consume the root handle to
1088  // make sure any handles to operations inside, that could have been affected
1089  // by actions, are invalidated.
1090  results.set(llvm::cast<OpResult>(getUpdated()),
1091  state.getPayloadOps(getRoot()));
1092  return overallDiag;
1093 }
1094 
1095 void transform::ForeachMatchOp::getEffects(
1096  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1097  // Bail if invalid.
1098  if (getOperation()->getNumOperands() < 1 ||
1099  getOperation()->getNumResults() < 1) {
1100  return modifiesPayload(effects);
1101  }
1102 
1103  consumesHandle(getRoot(), effects);
1104  producesHandle(getUpdated(), effects);
1105  modifiesPayload(effects);
1106 }
1107 
1108 /// Parses the comma-separated list of symbol reference pairs of the format
1109 /// `@matcher -> @action`.
1111  ArrayAttr &matchers,
1112  ArrayAttr &actions) {
1113  StringAttr matcher;
1114  StringAttr action;
1115  SmallVector<Attribute> matcherList;
1116  SmallVector<Attribute> actionList;
1117  do {
1118  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1119  parser.parseSymbolName(action)) {
1120  return failure();
1121  }
1122  matcherList.push_back(SymbolRefAttr::get(matcher));
1123  actionList.push_back(SymbolRefAttr::get(action));
1124  } while (parser.parseOptionalComma().succeeded());
1125 
1126  matchers = parser.getBuilder().getArrayAttr(matcherList);
1127  actions = parser.getBuilder().getArrayAttr(actionList);
1128  return success();
1129 }
1130 
1131 /// Prints the comma-separated list of symbol reference pairs of the format
1132 /// `@matcher -> @action`.
1134  ArrayAttr matchers, ArrayAttr actions) {
1135  printer.increaseIndent();
1136  printer.increaseIndent();
1137  for (auto &&[matcher, action, idx] : llvm::zip_equal(
1138  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1139  printer.printNewline();
1140  printer << cast<SymbolRefAttr>(matcher) << " -> "
1141  << cast<SymbolRefAttr>(action);
1142  if (idx != matchers.size() - 1)
1143  printer << ", ";
1144  }
1145  printer.decreaseIndent();
1146  printer.decreaseIndent();
1147 }
1148 
1150  if (getMatchers().size() != getActions().size())
1151  return emitOpError() << "expected the same number of matchers and actions";
1152  if (getMatchers().empty())
1153  return emitOpError() << "expected at least one match/action pair";
1154 
1155  llvm::SmallPtrSet<Attribute, 8> matcherNames;
1156  for (Attribute name : getMatchers()) {
1157  if (matcherNames.insert(name).second)
1158  continue;
1159  emitWarning() << "matcher " << name
1160  << " is used more than once, only the first match will apply";
1161  }
1162 
1163  return success();
1164 }
1165 
1166 /// Checks that the attributes of the function-like operation have correct
1167 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
1168 /// annotations being present even if they can be inferred from the body.
1170 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1171  bool alsoVerifyInternal = false) {
1172  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1173  llvm::SmallDenseSet<unsigned> consumedArguments;
1174  if (!op.isExternal()) {
1175  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1176  consumedArguments);
1177  }
1178  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1179  bool isConsumed =
1180  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1181  nullptr;
1182  bool isReadOnly =
1183  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1184  nullptr;
1185  if (isConsumed && isReadOnly) {
1186  return transformOp.emitSilenceableError()
1187  << "argument #" << i << " cannot be both readonly and consumed";
1188  }
1189  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1190  return transformOp.emitSilenceableError()
1191  << "must provide consumed/readonly status for arguments of "
1192  "external or called ops";
1193  }
1194  if (op.isExternal())
1195  continue;
1196 
1197  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1198  return transformOp.emitSilenceableError()
1199  << "argument #" << i
1200  << " is consumed in the body but is not marked as such";
1201  }
1202  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1203  // Cannot use op.emitWarning() here as it would attempt to verify the op
1204  // before printing, resulting in infinite recursion.
1205  emitWarning(op->getLoc())
1206  << "op argument #" << i
1207  << " is not consumed in the body but is marked as consumed";
1208  }
1209  }
1211 }
1212 
1213 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1214  SymbolTableCollection &symbolTable) {
1215  assert(getMatchers().size() == getActions().size());
1216  auto consumedAttr =
1217  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1218  for (auto &&[matcher, action] :
1219  llvm::zip_equal(getMatchers(), getActions())) {
1220  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1221  symbolTable.lookupNearestSymbolFrom(getOperation(),
1222  cast<SymbolRefAttr>(matcher)));
1223  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1224  symbolTable.lookupNearestSymbolFrom(getOperation(),
1225  cast<SymbolRefAttr>(action)));
1226  if (!matcherSymbol ||
1227  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1228  return emitError() << "unresolved matcher symbol " << matcher;
1229  if (!actionSymbol ||
1230  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1231  return emitError() << "unresolved action symbol " << action;
1232 
1233  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1234  /*emitWarnings=*/false,
1235  /*alsoVerifyInternal=*/true)
1236  .checkAndReport())) {
1237  return failure();
1238  }
1240  /*emitWarnings=*/false,
1241  /*alsoVerifyInternal=*/true)
1242  .checkAndReport())) {
1243  return failure();
1244  }
1245 
1246  ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
1247  ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
1248  if (matcherResults.size() != actionArguments.size()) {
1249  return emitError() << "mismatching number of matcher results and "
1250  "action arguments between "
1251  << matcher << " (" << matcherResults.size() << ") and "
1252  << action << " (" << actionArguments.size() << ")";
1253  }
1254  for (auto &&[i, matcherType, actionType] :
1255  llvm::enumerate(matcherResults, actionArguments)) {
1256  if (implementSameTransformInterface(matcherType, actionType))
1257  continue;
1258 
1259  return emitError() << "mismatching type interfaces for matcher result "
1260  "and action argument #"
1261  << i;
1262  }
1263 
1264  if (!actionSymbol.getResultTypes().empty()) {
1266  emitError() << "action symbol is not expected to have results";
1267  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1268  return diag;
1269  }
1270 
1271  if (matcherSymbol.getArgumentTypes().size() != 1 ||
1272  !implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
1273  getRoot().getType())) {
1275  emitOpError() << "expects matcher symbol to have one argument with "
1276  "the same transform interface as the first operand";
1277  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1278  return diag;
1279  }
1280 
1281  if (matcherSymbol.getArgAttr(0, consumedAttr)) {
1283  emitOpError()
1284  << "does not expect matcher symbol to consume its operand";
1285  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1286  return diag;
1287  }
1288  }
1289  return success();
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // ForeachOp
1294 //===----------------------------------------------------------------------===//
1295 
1297 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1298  transform::TransformResults &results,
1299  transform::TransformState &state) {
1300  SmallVector<SmallVector<Operation *>> resultOps(getNumResults(), {});
1301  // Store payload ops in a vector because ops may be removed from the mapping
1302  // by the TrackingRewriter while the iteration is in progress.
1303  SmallVector<Operation *> targets =
1304  llvm::to_vector(state.getPayloadOps(getTarget()));
1305  for (Operation *op : targets) {
1306  auto scope = state.make_region_scope(getBody());
1307  if (failed(state.mapBlockArguments(getIterationVariable(), {op})))
1309 
1310  // Execute loop body.
1311  for (Operation &transform : getBody().front().without_terminator()) {
1312  DiagnosedSilenceableFailure result = state.applyTransform(
1313  cast<transform::TransformOpInterface>(transform));
1314  if (!result.succeeded())
1315  return result;
1316  }
1317 
1318  // Append yielded payload ops to result list (if any).
1319  for (unsigned i = 0; i < getNumResults(); ++i) {
1320  auto yieldedOps = state.getPayloadOps(getYieldOp().getOperand(i));
1321  resultOps[i].append(yieldedOps.begin(), yieldedOps.end());
1322  }
1323  }
1324 
1325  for (unsigned i = 0; i < getNumResults(); ++i)
1326  results.set(llvm::cast<OpResult>(getResult(i)), resultOps[i]);
1327 
1329 }
1330 
1331 void transform::ForeachOp::getEffects(
1332  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1333  BlockArgument iterVar = getIterationVariable();
1334  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1335  return isHandleConsumed(iterVar, cast<TransformOpInterface>(&op));
1336  })) {
1337  consumesHandle(getTarget(), effects);
1338  } else {
1339  onlyReadsHandle(getTarget(), effects);
1340  }
1341 
1342  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1343  return doesModifyPayload(cast<TransformOpInterface>(&op));
1344  })) {
1345  modifiesPayload(effects);
1346  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1347  return doesReadPayload(cast<TransformOpInterface>(&op));
1348  })) {
1349  onlyReadsPayload(effects);
1350  }
1351 
1352  for (Value result : getResults())
1353  producesHandle(result, effects);
1354 }
1355 
1356 void transform::ForeachOp::getSuccessorRegions(
1357  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1358  Region *bodyRegion = &getBody();
1359  if (point.isParent()) {
1360  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1361  return;
1362  }
1363 
1364  // Branch back to the region or the parent.
1365  assert(point == getBody() && "unexpected region index");
1366  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1367  regions.emplace_back();
1368 }
1369 
1371 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1372  // The iteration variable op handle is mapped to a subset (one op to be
1373  // precise) of the payload ops of the ForeachOp operand.
1374  assert(point == getBody() && "unexpected region index");
1375  return getOperation()->getOperands();
1376 }
1377 
1378 transform::YieldOp transform::ForeachOp::getYieldOp() {
1379  return cast<transform::YieldOp>(getBody().front().getTerminator());
1380 }
1381 
1383  auto yieldOp = getYieldOp();
1384  if (getNumResults() != yieldOp.getNumOperands())
1385  return emitOpError() << "expects the same number of results as the "
1386  "terminator has operands";
1387  for (Value v : yieldOp.getOperands())
1388  if (!llvm::isa<TransformHandleTypeInterface>(v.getType()))
1389  return yieldOp->emitOpError("expects operands to have types implementing "
1390  "TransformHandleTypeInterface");
1391  return success();
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // GetParentOp
1396 //===----------------------------------------------------------------------===//
1397 
1399 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1400  transform::TransformResults &results,
1401  transform::TransformState &state) {
1402  SmallVector<Operation *> parents;
1403  DenseSet<Operation *> resultSet;
1404  for (Operation *target : state.getPayloadOps(getTarget())) {
1405  Operation *parent = target;
1406  for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1407  parent = parent->getParentOp();
1408  while (parent) {
1409  bool checkIsolatedFromAbove =
1410  !getIsolatedFromAbove() ||
1412  bool checkOpName = !getOpName().has_value() ||
1413  parent->getName().getStringRef() == *getOpName();
1414  if (checkIsolatedFromAbove && checkOpName)
1415  break;
1416  parent = parent->getParentOp();
1417  }
1418  if (!parent) {
1419  if (getAllowEmptyResults()) {
1420  results.set(llvm::cast<OpResult>(getResult()), parents);
1422  }
1424  emitSilenceableError()
1425  << "could not find a parent op that matches all requirements";
1426  diag.attachNote(target->getLoc()) << "target op";
1427  return diag;
1428  }
1429  }
1430  if (getDeduplicate()) {
1431  if (!resultSet.contains(parent)) {
1432  parents.push_back(parent);
1433  resultSet.insert(parent);
1434  }
1435  } else {
1436  parents.push_back(parent);
1437  }
1438  }
1439  results.set(llvm::cast<OpResult>(getResult()), parents);
1441 }
1442 
1443 //===----------------------------------------------------------------------===//
1444 // GetConsumersOfResult
1445 //===----------------------------------------------------------------------===//
1446 
1448 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1449  transform::TransformResults &results,
1450  transform::TransformState &state) {
1451  int64_t resultNumber = getResultNumber();
1452  auto payloadOps = state.getPayloadOps(getTarget());
1453  if (std::empty(payloadOps)) {
1454  results.set(cast<OpResult>(getResult()), {});
1456  }
1457  if (!llvm::hasSingleElement(payloadOps))
1458  return emitDefiniteFailure()
1459  << "handle must be mapped to exactly one payload op";
1460 
1461  Operation *target = *payloadOps.begin();
1462  if (target->getNumResults() <= resultNumber)
1463  return emitDefiniteFailure() << "result number overflow";
1464  results.set(llvm::cast<OpResult>(getResult()),
1465  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1467 }
1468 
1469 //===----------------------------------------------------------------------===//
1470 // GetDefiningOp
1471 //===----------------------------------------------------------------------===//
1472 
1474 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1475  transform::TransformResults &results,
1476  transform::TransformState &state) {
1477  SmallVector<Operation *> definingOps;
1478  for (Value v : state.getPayloadValues(getTarget())) {
1479  if (llvm::isa<BlockArgument>(v)) {
1481  emitSilenceableError() << "cannot get defining op of block argument";
1482  diag.attachNote(v.getLoc()) << "target value";
1483  return diag;
1484  }
1485  definingOps.push_back(v.getDefiningOp());
1486  }
1487  results.set(llvm::cast<OpResult>(getResult()), definingOps);
1489 }
1490 
1491 //===----------------------------------------------------------------------===//
1492 // GetProducerOfOperand
1493 //===----------------------------------------------------------------------===//
1494 
1496 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1497  transform::TransformResults &results,
1498  transform::TransformState &state) {
1499  int64_t operandNumber = getOperandNumber();
1500  SmallVector<Operation *> producers;
1501  for (Operation *target : state.getPayloadOps(getTarget())) {
1502  Operation *producer =
1503  target->getNumOperands() <= operandNumber
1504  ? nullptr
1505  : target->getOperand(operandNumber).getDefiningOp();
1506  if (!producer) {
1508  emitSilenceableError()
1509  << "could not find a producer for operand number: " << operandNumber
1510  << " of " << *target;
1511  diag.attachNote(target->getLoc()) << "target op";
1512  return diag;
1513  }
1514  producers.push_back(producer);
1515  }
1516  results.set(llvm::cast<OpResult>(getResult()), producers);
1518 }
1519 
1520 //===----------------------------------------------------------------------===//
1521 // GetOperandOp
1522 //===----------------------------------------------------------------------===//
1523 
1525 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1526  transform::TransformResults &results,
1527  transform::TransformState &state) {
1528  SmallVector<Value> operands;
1529  for (Operation *target : state.getPayloadOps(getTarget())) {
1530  SmallVector<int64_t> operandPositions;
1532  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1533  target->getNumOperands(), operandPositions);
1534  if (diag.isSilenceableFailure()) {
1535  diag.attachNote(target->getLoc())
1536  << "while considering positions of this payload operation";
1537  return diag;
1538  }
1539  llvm::append_range(operands,
1540  llvm::map_range(operandPositions, [&](int64_t pos) {
1541  return target->getOperand(pos);
1542  }));
1543  }
1544  results.setValues(cast<OpResult>(getResult()), operands);
1546 }
1547 
1549  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1550  getIsInverted(), getIsAll());
1551 }
1552 
1553 //===----------------------------------------------------------------------===//
1554 // GetResultOp
1555 //===----------------------------------------------------------------------===//
1556 
1558 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1559  transform::TransformResults &results,
1560  transform::TransformState &state) {
1561  SmallVector<Value> opResults;
1562  for (Operation *target : state.getPayloadOps(getTarget())) {
1563  SmallVector<int64_t> resultPositions;
1565  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1566  target->getNumResults(), resultPositions);
1567  if (diag.isSilenceableFailure()) {
1568  diag.attachNote(target->getLoc())
1569  << "while considering positions of this payload operation";
1570  return diag;
1571  }
1572  llvm::append_range(opResults,
1573  llvm::map_range(resultPositions, [&](int64_t pos) {
1574  return target->getResult(pos);
1575  }));
1576  }
1577  results.setValues(cast<OpResult>(getResult()), opResults);
1579 }
1580 
1582  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1583  getIsInverted(), getIsAll());
1584 }
1585 
1586 //===----------------------------------------------------------------------===//
1587 // GetTypeOp
1588 //===----------------------------------------------------------------------===//
1589 
1590 void transform::GetTypeOp::getEffects(
1591  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1592  onlyReadsHandle(getValue(), effects);
1593  producesHandle(getResult(), effects);
1594  onlyReadsPayload(effects);
1595 }
1596 
1598 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1599  transform::TransformResults &results,
1600  transform::TransformState &state) {
1601  SmallVector<Attribute> params;
1602  for (Value value : state.getPayloadValues(getValue())) {
1603  Type type = value.getType();
1604  if (getElemental()) {
1605  if (auto shaped = dyn_cast<ShapedType>(type)) {
1606  type = shaped.getElementType();
1607  }
1608  }
1609  params.push_back(TypeAttr::get(type));
1610  }
1611  results.setParams(getResult().cast<OpResult>(), params);
1613 }
1614 
1615 //===----------------------------------------------------------------------===//
1616 // IncludeOp
1617 //===----------------------------------------------------------------------===//
1618 
1619 /// Applies the transform ops contained in `block`. Maps `results` to the same
1620 /// values as the operands of the block terminator.
1622 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1624  transform::TransformResults &results) {
1625  // Apply the sequenced ops one by one.
1626  for (Operation &transform : block.without_terminator()) {
1628  state.applyTransform(cast<transform::TransformOpInterface>(transform));
1629  if (result.isDefiniteFailure())
1630  return result;
1631 
1632  if (result.isSilenceableFailure()) {
1633  if (mode == transform::FailurePropagationMode::Propagate) {
1634  // Propagate empty results in case of early exit.
1635  forwardEmptyOperands(&block, state, results);
1636  return result;
1637  }
1638  (void)result.silence();
1639  }
1640  }
1641 
1642  // Forward the operation mapping for values yielded from the sequence to the
1643  // values produced by the sequence op.
1644  transform::detail::forwardTerminatorOperands(&block, state, results);
1646 }
1647 
1649 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1650  transform::TransformResults &results,
1651  transform::TransformState &state) {
1652  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1653  getOperation(), getTarget());
1654  assert(callee && "unverified reference to unknown symbol");
1655 
1656  if (callee.isExternal())
1657  return emitDefiniteFailure() << "unresolved external named sequence";
1658 
1659  // Map operands to block arguments.
1661  detail::prepareValueMappings(mappings, getOperands(), state);
1662  auto scope = state.make_region_scope(callee.getBody());
1663  for (auto &&[arg, map] :
1664  llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1665  if (failed(state.mapBlockArgument(arg, map)))
1667  }
1668 
1670  callee.getBody().front(), getFailurePropagationMode(), state, results);
1671  mappings.clear();
1673  mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1674  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1675  results.setMappedValues(result, mapping);
1676  return result;
1677 }
1678 
1680 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1681 
1682 void transform::IncludeOp::getEffects(
1683  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1684  // Always mark as modifying the payload.
1685  // TODO: a mechanism to annotate effects on payload. Even when all handles are
1686  // only read, the payload may still be modified, so we currently stay on the
1687  // conservative side and always indicate modification. This may prevent some
1688  // code reordering.
1689  modifiesPayload(effects);
1690 
1691  // Results are always produced.
1692  producesHandle(getResults(), effects);
1693 
1694  // Adds default effects to operands and results. This will be added if
1695  // preconditions fail so the trait verifier doesn't complain about missing
1696  // effects and the real precondition failure is reported later on.
1697  auto defaultEffects = [&] { onlyReadsHandle(getOperands(), effects); };
1698 
1699  // Bail if the callee is unknown. This may run as part of the verification
1700  // process before we verified the validity of the callee or of this op.
1701  auto target =
1702  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1703  if (!target)
1704  return defaultEffects();
1705  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1706  getOperation(), getTarget());
1707  if (!callee)
1708  return defaultEffects();
1709  DiagnosedSilenceableFailure earlyVerifierResult =
1710  verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1711  if (!earlyVerifierResult.succeeded()) {
1712  (void)earlyVerifierResult.silence();
1713  return defaultEffects();
1714  }
1715 
1716  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1717  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1718  consumesHandle(getOperand(i), effects);
1719  else
1720  onlyReadsHandle(getOperand(i), effects);
1721  }
1722 }
1723 
1725 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1726  // Access through indirection and do additional checking because this may be
1727  // running before the main op verifier.
1728  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1729  if (!targetAttr)
1730  return emitOpError() << "expects a 'target' symbol reference attribute";
1731 
1732  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1733  *this, targetAttr);
1734  if (!target)
1735  return emitOpError() << "does not reference a named transform sequence";
1736 
1737  FunctionType fnType = target.getFunctionType();
1738  if (fnType.getNumInputs() != getNumOperands())
1739  return emitError("incorrect number of operands for callee");
1740 
1741  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1742  if (getOperand(i).getType() != fnType.getInput(i)) {
1743  return emitOpError("operand type mismatch: expected operand type ")
1744  << fnType.getInput(i) << ", but provided "
1745  << getOperand(i).getType() << " for operand number " << i;
1746  }
1747  }
1748 
1749  if (fnType.getNumResults() != getNumResults())
1750  return emitError("incorrect number of results for callee");
1751 
1752  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1753  Type resultType = getResult(i).getType();
1754  Type funcType = fnType.getResult(i);
1755  if (!implementSameTransformInterface(resultType, funcType)) {
1756  return emitOpError() << "type of result #" << i
1757  << " must implement the same transform dialect "
1758  "interface as the corresponding callee result";
1759  }
1760  }
1761 
1763  cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
1764  /*alsoVerifyInternal=*/true)
1765  .checkAndReport();
1766 }
1767 
1768 //===----------------------------------------------------------------------===//
1769 // MatchOperationEmptyOp
1770 //===----------------------------------------------------------------------===//
1771 
1772 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
1773  ::std::optional<::mlir::Operation *> maybeCurrent,
1775  if (!maybeCurrent.has_value()) {
1776  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
1778  }
1779  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
1780  return emitSilenceableError() << "operation is not empty";
1781 }
1782 
1783 //===----------------------------------------------------------------------===//
1784 // MatchOperationNameOp
1785 //===----------------------------------------------------------------------===//
1786 
1787 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
1788  Operation *current, transform::TransformResults &results,
1789  transform::TransformState &state) {
1790  StringRef currentOpName = current->getName().getStringRef();
1791  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1792  if (acceptedAttr.getValue() == currentOpName)
1794  }
1795  return emitSilenceableError() << "wrong operation name";
1796 }
1797 
1798 //===----------------------------------------------------------------------===//
1799 // MatchParamCmpIOp
1800 //===----------------------------------------------------------------------===//
1801 
1803 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1804  transform::TransformResults &results,
1805  transform::TransformState &state) {
1806  auto signedAPIntAsString = [&](const APInt &value) {
1807  std::string str;
1808  llvm::raw_string_ostream os(str);
1809  value.print(os, /*isSigned=*/true);
1810  return os.str();
1811  };
1812 
1813  ArrayRef<Attribute> params = state.getParams(getParam());
1814  ArrayRef<Attribute> references = state.getParams(getReference());
1815 
1816  if (params.size() != references.size()) {
1817  return emitSilenceableError()
1818  << "parameters have different payload lengths (" << params.size()
1819  << " vs " << references.size() << ")";
1820  }
1821 
1822  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1823  auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1824  auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1825  if (!intAttr || !refAttr) {
1826  return emitDefiniteFailure()
1827  << "non-integer parameter value not expected";
1828  }
1829  if (intAttr.getType() != refAttr.getType()) {
1830  return emitDefiniteFailure()
1831  << "mismatching integer attribute types in parameter #" << i;
1832  }
1833  APInt value = intAttr.getValue();
1834  APInt refValue = refAttr.getValue();
1835 
1836  // TODO: this copy will not be necessary in C++20.
1837  int64_t position = i;
1838  auto reportError = [&](StringRef direction) {
1840  emitSilenceableError() << "expected parameter to be " << direction
1841  << " " << signedAPIntAsString(refValue)
1842  << ", got " << signedAPIntAsString(value);
1843  diag.attachNote(getParam().getLoc())
1844  << "value # " << position
1845  << " associated with the parameter defined here";
1846  return diag;
1847  };
1848 
1849  switch (getPredicate()) {
1850  case MatchCmpIPredicate::eq:
1851  if (value.eq(refValue))
1852  break;
1853  return reportError("equal to");
1854  case MatchCmpIPredicate::ne:
1855  if (value.ne(refValue))
1856  break;
1857  return reportError("not equal to");
1858  case MatchCmpIPredicate::lt:
1859  if (value.slt(refValue))
1860  break;
1861  return reportError("less than");
1862  case MatchCmpIPredicate::le:
1863  if (value.sle(refValue))
1864  break;
1865  return reportError("less than or equal to");
1866  case MatchCmpIPredicate::gt:
1867  if (value.sgt(refValue))
1868  break;
1869  return reportError("greater than");
1870  case MatchCmpIPredicate::ge:
1871  if (value.sge(refValue))
1872  break;
1873  return reportError("greater than or equal to");
1874  }
1875  }
1877 }
1878 
1879 void transform::MatchParamCmpIOp::getEffects(
1880  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1881  onlyReadsHandle(getParam(), effects);
1882  onlyReadsHandle(getReference(), effects);
1883 }
1884 
1885 //===----------------------------------------------------------------------===//
1886 // ParamConstantOp
1887 //===----------------------------------------------------------------------===//
1888 
1890 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
1891  transform::TransformResults &results,
1892  transform::TransformState &state) {
1893  results.setParams(cast<OpResult>(getParam()), {getValue()});
1895 }
1896 
1897 //===----------------------------------------------------------------------===//
1898 // MergeHandlesOp
1899 //===----------------------------------------------------------------------===//
1900 
1902 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
1903  transform::TransformResults &results,
1904  transform::TransformState &state) {
1905  ValueRange handles = getHandles();
1906  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
1907  SmallVector<Operation *> operations;
1908  for (Value operand : handles)
1909  llvm::append_range(operations, state.getPayloadOps(operand));
1910  if (!getDeduplicate()) {
1911  results.set(llvm::cast<OpResult>(getResult()), operations);
1913  }
1914 
1915  SetVector<Operation *> uniqued(operations.begin(), operations.end());
1916  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
1918  }
1919 
1920  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
1921  SmallVector<Attribute> attrs;
1922  for (Value attribute : handles)
1923  llvm::append_range(attrs, state.getParams(attribute));
1924  if (!getDeduplicate()) {
1925  results.setParams(cast<OpResult>(getResult()), attrs);
1927  }
1928 
1929  SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
1930  results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
1932  }
1933 
1934  assert(
1935  llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
1936  "expected value handle type");
1937  SmallVector<Value> payloadValues;
1938  for (Value value : handles)
1939  llvm::append_range(payloadValues, state.getPayloadValues(value));
1940  if (!getDeduplicate()) {
1941  results.setValues(cast<OpResult>(getResult()), payloadValues);
1943  }
1944 
1945  SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
1946  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
1948 }
1949 
1950 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
1951  // Handles may be the same if deduplicating is enabled.
1952  return getDeduplicate();
1953 }
1954 
1955 void transform::MergeHandlesOp::getEffects(
1956  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1957  onlyReadsHandle(getHandles(), effects);
1958  producesHandle(getResult(), effects);
1959 
1960  // There are no effects on the Payload IR as this is only a handle
1961  // manipulation.
1962 }
1963 
1964 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
1965  if (getDeduplicate() || getHandles().size() != 1)
1966  return {};
1967 
1968  // If deduplication is not required and there is only one operand, it can be
1969  // used directly instead of merging.
1970  return getHandles().front();
1971 }
1972 
1973 //===----------------------------------------------------------------------===//
1974 // NamedSequenceOp
1975 //===----------------------------------------------------------------------===//
1976 
1978 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
1979  transform::TransformResults &results,
1980  transform::TransformState &state) {
1981  if (isExternal())
1982  return emitDefiniteFailure() << "unresolved external named sequence";
1983 
1984  // Map the entry block argument to the list of operations.
1985  // Note: this is the same implementation as PossibleTopLevelTransformOp but
1986  // without attaching the interface / trait since that is tailored to a
1987  // dangling top-level op that does not get "called".
1988  auto scope = state.make_region_scope(getBody());
1990  state, this->getOperation(), getBody())))
1992 
1993  return applySequenceBlock(getBody().front(),
1994  FailurePropagationMode::Propagate, state, results);
1995 }
1996 
1997 void transform::NamedSequenceOp::getEffects(
1998  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
1999 
2001  OperationState &result) {
2003  parser, result, /*allowVariadic=*/false,
2004  getFunctionTypeAttrName(result.name),
2005  [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2007  std::string &) { return builder.getFunctionType(inputs, results); },
2008  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2009 }
2010 
2013  printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2014  getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2015  getResAttrsAttrName());
2016 }
2017 
2018 /// Verifies that a symbol function-like transform dialect operation has the
2019 /// signature and the terminator that have conforming types, i.e., types
2020 /// implementing the same transform dialect type interface. If `allowExternal`
2021 /// is set, allow external symbols (declarations) and don't check the terminator
2022 /// as it may not exist.
2024 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2025  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2028  << "cannot be defined inside another transform op";
2029  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2030  return diag;
2031  }
2032 
2033  if (op.isExternal() || op.getFunctionBody().empty()) {
2034  if (allowExternal)
2036 
2037  return emitSilenceableFailure(op) << "cannot be external";
2038  }
2039 
2040  if (op.getFunctionBody().front().empty())
2041  return emitSilenceableFailure(op) << "expected a non-empty body block";
2042 
2043  Operation *terminator = &op.getFunctionBody().front().back();
2044  if (!isa<transform::YieldOp>(terminator)) {
2046  << "expected '"
2047  << transform::YieldOp::getOperationName()
2048  << "' as terminator";
2049  diag.attachNote(terminator->getLoc()) << "terminator";
2050  return diag;
2051  }
2052 
2053  if (terminator->getNumOperands() != op.getResultTypes().size()) {
2054  return emitSilenceableFailure(terminator)
2055  << "expected terminator to have as many operands as the parent op "
2056  "has results";
2057  }
2058  for (auto [i, operandType, resultType] : llvm::zip_equal(
2059  llvm::seq<unsigned>(0, terminator->getNumOperands()),
2060  terminator->getOperands().getType(), op.getResultTypes())) {
2061  if (operandType == resultType)
2062  continue;
2063  return emitSilenceableFailure(terminator)
2064  << "the type of the terminator operand #" << i
2065  << " must match the type of the corresponding parent op result ("
2066  << operandType << " vs " << resultType << ")";
2067  }
2068 
2070 }
2071 
2072 /// Verification of a NamedSequenceOp. This does not report the error
2073 /// immediately, so it can be used to check for op's well-formedness before the
2074 /// verifier runs, e.g., during trait verification.
2076 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2077  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2078  if (!parent->getAttr(
2079  transform::TransformDialect::kWithNamedSequenceAttrName)) {
2082  << "expects the parent symbol table to have the '"
2083  << transform::TransformDialect::kWithNamedSequenceAttrName
2084  << "' attribute";
2085  diag.attachNote(parent->getLoc()) << "symbol table operation";
2086  return diag;
2087  }
2088  }
2089 
2090  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2093  << "cannot be defined inside another transform op";
2094  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2095  return diag;
2096  }
2097 
2098  if (op.isExternal() || op.getBody().empty())
2099  return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2100  emitWarnings);
2101 
2102  if (op.getBody().front().empty())
2103  return emitSilenceableFailure(op) << "expected a non-empty body block";
2104 
2105  Operation *terminator = &op.getBody().front().back();
2106  if (!isa<transform::YieldOp>(terminator)) {
2108  << "expected '"
2109  << transform::YieldOp::getOperationName()
2110  << "' as terminator";
2111  diag.attachNote(terminator->getLoc()) << "terminator";
2112  return diag;
2113  }
2114 
2115  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2116  return emitSilenceableFailure(terminator)
2117  << "expected terminator to have as many operands as the parent op "
2118  "has results";
2119  }
2120  for (auto [i, operandType, resultType] :
2121  llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2122  terminator->getOperands().getType(),
2123  op.getFunctionType().getResults())) {
2124  if (operandType == resultType)
2125  continue;
2126  return emitSilenceableFailure(terminator)
2127  << "the type of the terminator operand #" << i
2128  << " must match the type of the corresponding parent op result ("
2129  << operandType << " vs " << resultType << ")";
2130  }
2131 
2132  auto funcOp = cast<FunctionOpInterface>(*op);
2134  verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2135  if (!diag.succeeded())
2136  return diag;
2137 
2138  return verifyYieldingSingleBlockOp(funcOp,
2139  /*allowExternal=*/true);
2140 }
2141 
2143  // Actual verification happens in a separate function for reusability.
2144  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2145 }
2146 
2147 template <typename FnTy>
2148 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2149  Type bbArgType, TypeRange extraBindingTypes,
2150  FnTy bodyBuilder) {
2151  SmallVector<Type> types;
2152  types.reserve(1 + extraBindingTypes.size());
2153  types.push_back(bbArgType);
2154  llvm::append_range(types, extraBindingTypes);
2155 
2156  OpBuilder::InsertionGuard guard(builder);
2157  Region *region = state.regions.back().get();
2158  Block *bodyBlock =
2159  builder.createBlock(region, region->begin(), types,
2160  SmallVector<Location>(types.size(), state.location));
2161 
2162  // Populate body.
2163  builder.setInsertionPointToStart(bodyBlock);
2164  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2165  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2166  } else {
2167  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2168  bodyBlock->getArguments().drop_front());
2169  }
2170 }
2171 
2172 void transform::NamedSequenceOp::build(OpBuilder &builder,
2173  OperationState &state, StringRef symName,
2174  Type rootType, TypeRange resultTypes,
2175  SequenceBodyBuilderFn bodyBuilder,
2177  ArrayRef<DictionaryAttr> argAttrs) {
2178  state.addAttribute(SymbolTable::getSymbolAttrName(),
2179  builder.getStringAttr(symName));
2180  state.addAttribute(getFunctionTypeAttrName(state.name),
2182  rootType, resultTypes)));
2183  state.attributes.append(attrs.begin(), attrs.end());
2184  state.addRegion();
2185 
2186  buildSequenceBody(builder, state, rootType,
2187  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2188 }
2189 
2190 //===----------------------------------------------------------------------===//
2191 // NumAssociationsOp
2192 //===----------------------------------------------------------------------===//
2193 
2195 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2196  transform::TransformResults &results,
2197  transform::TransformState &state) {
2198  size_t numAssociations =
2199  llvm::TypeSwitch<Type, size_t>(getHandle().getType())
2200  .Case([&](TransformHandleTypeInterface opHandle) {
2201  return llvm::range_size(state.getPayloadOps(getHandle()));
2202  })
2203  .Case([&](TransformValueHandleTypeInterface valueHandle) {
2204  return llvm::range_size(state.getPayloadValues(getHandle()));
2205  })
2206  .Case([&](TransformParamTypeInterface param) {
2207  return llvm::range_size(state.getParams(getHandle()));
2208  })
2209  .Default([](Type) {
2210  llvm_unreachable("unknown kind of transform dialect type");
2211  return 0;
2212  });
2213  results.setParams(getNum().cast<OpResult>(),
2214  rewriter.getI64IntegerAttr(numAssociations));
2216 }
2217 
2219  // Verify that the result type accepts an i64 attribute as payload.
2220  auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
2221  return resultType
2222  .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2223  .checkAndReport();
2224 }
2225 
2226 //===----------------------------------------------------------------------===//
2227 // SelectOp
2228 //===----------------------------------------------------------------------===//
2229 
2231 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2232  transform::TransformResults &results,
2233  transform::TransformState &state) {
2234  SmallVector<Operation *> result;
2235  auto payloadOps = state.getPayloadOps(getTarget());
2236  for (Operation *op : payloadOps) {
2237  if (op->getName().getStringRef() == getOpName())
2238  result.push_back(op);
2239  }
2240  results.set(cast<OpResult>(getResult()), result);
2242 }
2243 
2244 //===----------------------------------------------------------------------===//
2245 // SplitHandleOp
2246 //===----------------------------------------------------------------------===//
2247 
2248 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2249  Value target, int64_t numResultHandles) {
2250  result.addOperands(target);
2251  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2252 }
2253 
2255 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2256  transform::TransformResults &results,
2257  transform::TransformState &state) {
2258  int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2259  auto produceNumOpsError = [&]() {
2260  return emitSilenceableError()
2261  << getHandle() << " expected to contain " << this->getNumResults()
2262  << " payload ops but it contains " << numPayloadOps
2263  << " payload ops";
2264  };
2265 
2266  // Fail if there are more payload ops than results and no overflow result was
2267  // specified.
2268  if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2269  return produceNumOpsError();
2270 
2271  // Fail if there are more results than payload ops. Unless:
2272  // - "fail_on_payload_too_small" is set to "false", or
2273  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2274  if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2275  (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
2276  return produceNumOpsError();
2277 
2278  // Distribute payload ops.
2279  SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
2280  if (getOverflowResult())
2281  resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2282  getNumResults());
2283  for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
2284  int64_t resultNum = en.index();
2285  if (resultNum >= getNumResults())
2286  resultNum = *getOverflowResult();
2287  resultHandles[resultNum].push_back(en.value());
2288  }
2289 
2290  // Set transform op results.
2291  for (auto &&it : llvm::enumerate(resultHandles))
2292  results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2293 
2295 }
2296 
2297 void transform::SplitHandleOp::getEffects(
2298  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2299  onlyReadsHandle(getHandle(), effects);
2300  producesHandle(getResults(), effects);
2301  // There are no effects on the Payload IR as this is only a handle
2302  // manipulation.
2303 }
2304 
2306  if (getOverflowResult().has_value() &&
2307  !(*getOverflowResult() < getNumResults()))
2308  return emitOpError("overflow_result is not a valid result index");
2309  return success();
2310 }
2311 
2312 //===----------------------------------------------------------------------===//
2313 // ReplicateOp
2314 //===----------------------------------------------------------------------===//
2315 
2317 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2318  transform::TransformResults &results,
2319  transform::TransformState &state) {
2320  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2321  for (const auto &en : llvm::enumerate(getHandles())) {
2322  Value handle = en.value();
2323  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2324  SmallVector<Operation *> current =
2325  llvm::to_vector(state.getPayloadOps(handle));
2326  SmallVector<Operation *> payload;
2327  payload.reserve(numRepetitions * current.size());
2328  for (unsigned i = 0; i < numRepetitions; ++i)
2329  llvm::append_range(payload, current);
2330  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2331  } else {
2332  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2333  "expected param type");
2334  ArrayRef<Attribute> current = state.getParams(handle);
2335  SmallVector<Attribute> params;
2336  params.reserve(numRepetitions * current.size());
2337  for (unsigned i = 0; i < numRepetitions; ++i)
2338  llvm::append_range(params, current);
2339  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2340  params);
2341  }
2342  }
2344 }
2345 
2346 void transform::ReplicateOp::getEffects(
2347  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2348  onlyReadsHandle(getPattern(), effects);
2349  onlyReadsHandle(getHandles(), effects);
2350  producesHandle(getReplicated(), effects);
2351 }
2352 
2353 //===----------------------------------------------------------------------===//
2354 // SequenceOp
2355 //===----------------------------------------------------------------------===//
2356 
2358 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2359  transform::TransformResults &results,
2360  transform::TransformState &state) {
2361  // Map the entry block argument to the list of operations.
2362  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2363  if (failed(mapBlockArguments(state)))
2365 
2366  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2367  results);
2368 }
2369 
2371  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2372  Type &rootType,
2373  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2374  SmallVectorImpl<Type> &extraBindingTypes) {
2375  OpAsmParser::UnresolvedOperand rootOperand;
2376  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2377  if (!hasRoot.has_value()) {
2378  root = std::nullopt;
2379  return success();
2380  }
2381  if (failed(hasRoot.value()))
2382  return failure();
2383  root = rootOperand;
2384 
2385  if (succeeded(parser.parseOptionalComma())) {
2386  if (failed(parser.parseOperandList(extraBindings)))
2387  return failure();
2388  }
2389  if (failed(parser.parseColon()))
2390  return failure();
2391 
2392  // The paren is truly optional.
2393  (void)parser.parseOptionalLParen();
2394 
2395  if (failed(parser.parseType(rootType))) {
2396  return failure();
2397  }
2398 
2399  if (!extraBindings.empty()) {
2400  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2401  return failure();
2402  }
2403 
2404  if (extraBindingTypes.size() != extraBindings.size()) {
2405  return parser.emitError(parser.getNameLoc(),
2406  "expected types to be provided for all operands");
2407  }
2408 
2409  // The paren is truly optional.
2410  (void)parser.parseOptionalRParen();
2411  return success();
2412 }
2413 
2415  Value root, Type rootType,
2416  ValueRange extraBindings,
2417  TypeRange extraBindingTypes) {
2418  if (!root)
2419  return;
2420 
2421  printer << root;
2422  bool hasExtras = !extraBindings.empty();
2423  if (hasExtras) {
2424  printer << ", ";
2425  printer.printOperands(extraBindings);
2426  }
2427 
2428  printer << " : ";
2429  if (hasExtras)
2430  printer << "(";
2431 
2432  printer << rootType;
2433  if (hasExtras) {
2434  printer << ", ";
2435  llvm::interleaveComma(extraBindingTypes, printer.getStream());
2436  printer << ")";
2437  }
2438 }
2439 
2440 /// Returns `true` if the given op operand may be consuming the handle value in
2441 /// the Transform IR. That is, if it may have a Free effect on it.
2443  // Conservatively assume the effect being present in absence of the interface.
2444  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2445  if (!iface)
2446  return true;
2447 
2448  return isHandleConsumed(use.get(), iface);
2449 }
2450 
2453  function_ref<InFlightDiagnostic()> reportError) {
2454  OpOperand *potentialConsumer = nullptr;
2455  for (OpOperand &use : value.getUses()) {
2456  if (!isValueUsePotentialConsumer(use))
2457  continue;
2458 
2459  if (!potentialConsumer) {
2460  potentialConsumer = &use;
2461  continue;
2462  }
2463 
2464  InFlightDiagnostic diag = reportError()
2465  << " has more than one potential consumer";
2466  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2467  << "used here as operand #" << potentialConsumer->getOperandNumber();
2468  diag.attachNote(use.getOwner()->getLoc())
2469  << "used here as operand #" << use.getOperandNumber();
2470  return diag;
2471  }
2472 
2473  return success();
2474 }
2475 
2477  assert(getBodyBlock()->getNumArguments() >= 1 &&
2478  "the number of arguments must have been verified to be more than 1 by "
2479  "PossibleTopLevelTransformOpTrait");
2480 
2481  if (!getRoot() && !getExtraBindings().empty()) {
2482  return emitOpError()
2483  << "does not expect extra operands when used as top-level";
2484  }
2485 
2486  // Check if a block argument has more than one consuming use.
2487  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2488  if (failed(checkDoubleConsume(arg, [this, arg]() {
2489  return (emitOpError() << "block argument #" << arg.getArgNumber());
2490  }))) {
2491  return failure();
2492  }
2493  }
2494 
2495  // Check properties of the nested operations they cannot check themselves.
2496  for (Operation &child : *getBodyBlock()) {
2497  if (!isa<TransformOpInterface>(child) &&
2498  &child != &getBodyBlock()->back()) {
2500  emitOpError()
2501  << "expected children ops to implement TransformOpInterface";
2502  diag.attachNote(child.getLoc()) << "op without interface";
2503  return diag;
2504  }
2505 
2506  for (OpResult result : child.getResults()) {
2507  auto report = [&]() {
2508  return (child.emitError() << "result #" << result.getResultNumber());
2509  };
2510  if (failed(checkDoubleConsume(result, report)))
2511  return failure();
2512  }
2513  }
2514 
2515  if (!getBodyBlock()->mightHaveTerminator())
2516  return emitOpError() << "expects to have a terminator in the body";
2517 
2518  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2519  getOperation()->getResultTypes()) {
2520  InFlightDiagnostic diag = emitOpError()
2521  << "expects the types of the terminator operands "
2522  "to match the types of the result";
2523  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2524  return diag;
2525  }
2526  return success();
2527 }
2528 
2529 void transform::SequenceOp::getEffects(
2530  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2531  getPotentialTopLevelEffects(effects);
2532 }
2533 
2535 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2536  assert(point == getBody() && "unexpected region index");
2537  if (getOperation()->getNumOperands() > 0)
2538  return getOperation()->getOperands();
2539  return OperandRange(getOperation()->operand_end(),
2540  getOperation()->operand_end());
2541 }
2542 
2543 void transform::SequenceOp::getSuccessorRegions(
2544  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2545  if (point.isParent()) {
2546  Region *bodyRegion = &getBody();
2547  regions.emplace_back(bodyRegion, getNumOperands() != 0
2548  ? bodyRegion->getArguments()
2550  return;
2551  }
2552 
2553  assert(point == getBody() && "unexpected region index");
2554  regions.emplace_back(getOperation()->getResults());
2555 }
2556 
2557 void transform::SequenceOp::getRegionInvocationBounds(
2558  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2559  (void)operands;
2560  bounds.emplace_back(1, 1);
2561 }
2562 
2563 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2564  TypeRange resultTypes,
2565  FailurePropagationMode failurePropagationMode,
2566  Value root,
2567  SequenceBodyBuilderFn bodyBuilder) {
2568  build(builder, state, resultTypes, failurePropagationMode, root,
2569  /*extra_bindings=*/ValueRange());
2570  Type bbArgType = root.getType();
2571  buildSequenceBody(builder, state, bbArgType,
2572  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2573 }
2574 
2575 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2576  TypeRange resultTypes,
2577  FailurePropagationMode failurePropagationMode,
2578  Value root, ValueRange extraBindings,
2579  SequenceBodyBuilderArgsFn bodyBuilder) {
2580  build(builder, state, resultTypes, failurePropagationMode, root,
2581  extraBindings);
2582  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2583  bodyBuilder);
2584 }
2585 
2586 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2587  TypeRange resultTypes,
2588  FailurePropagationMode failurePropagationMode,
2589  Type bbArgType,
2590  SequenceBodyBuilderFn bodyBuilder) {
2591  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2592  /*extra_bindings=*/ValueRange());
2593  buildSequenceBody(builder, state, bbArgType,
2594  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2595 }
2596 
2597 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2598  TypeRange resultTypes,
2599  FailurePropagationMode failurePropagationMode,
2600  Type bbArgType, TypeRange extraBindingTypes,
2601  SequenceBodyBuilderArgsFn bodyBuilder) {
2602  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2603  /*extra_bindings=*/ValueRange());
2604  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2605 }
2606 
2607 //===----------------------------------------------------------------------===//
2608 // PrintOp
2609 //===----------------------------------------------------------------------===//
2610 
2611 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2612  StringRef name) {
2613  if (!name.empty())
2614  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2615 }
2616 
2617 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2618  Value target, StringRef name) {
2619  result.addOperands({target});
2620  build(builder, result, name);
2621 }
2622 
2624 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2625  transform::TransformResults &results,
2626  transform::TransformState &state) {
2627  llvm::outs() << "[[[ IR printer: ";
2628  if (getName().has_value())
2629  llvm::outs() << *getName() << " ";
2630 
2631  OpPrintingFlags printFlags;
2632  if (getAssumeVerified().value_or(false))
2633  printFlags.assumeVerified();
2634  if (getUseLocalScope().value_or(false))
2635  printFlags.useLocalScope();
2636  if (getSkipRegions().value_or(false))
2637  printFlags.skipRegions();
2638 
2639  if (!getTarget()) {
2640  llvm::outs() << "top-level ]]]\n";
2641  state.getTopLevel()->print(llvm::outs(), printFlags);
2642  llvm::outs() << "\n";
2644  }
2645 
2646  llvm::outs() << "]]]\n";
2647  for (Operation *target : state.getPayloadOps(getTarget())) {
2648  target->print(llvm::outs(), printFlags);
2649  llvm::outs() << "\n";
2650  }
2651 
2653 }
2654 
2655 void transform::PrintOp::getEffects(
2656  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2657  // We don't really care about mutability here, but `getTarget` now
2658  // unconditionally casts to a specific type before verification could run
2659  // here.
2660  if (!getTargetMutable().empty())
2661  onlyReadsHandle(getTargetMutable()[0].get(), effects);
2662  onlyReadsPayload(effects);
2663 
2664  // There is no resource for stderr file descriptor, so just declare print
2665  // writes into the default resource.
2666  effects.emplace_back(MemoryEffects::Write::get());
2667 }
2668 
2669 //===----------------------------------------------------------------------===//
2670 // VerifyOp
2671 //===----------------------------------------------------------------------===//
2672 
2674 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
2675  Operation *target,
2677  transform::TransformState &state) {
2678  if (failed(::mlir::verify(target))) {
2680  << "failed to verify payload op";
2681  diag.attachNote(target->getLoc()) << "payload op";
2682  return diag;
2683  }
2685 }
2686 
2687 void transform::VerifyOp::getEffects(
2688  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2689  transform::onlyReadsHandle(getTarget(), effects);
2690 }
2691 
2692 //===----------------------------------------------------------------------===//
2693 // YieldOp
2694 //===----------------------------------------------------------------------===//
2695 
2696 void transform::YieldOp::getEffects(
2697  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2698  onlyReadsHandle(getOperands(), effects);
2699 }
static MLIRContext * getContext(OpFoldResult val)
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs)
Checks that two types are the same or can be cast into one another.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DBGS_MATCHER()
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
#define DEBUG_MATCHER(x)
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
#define DBGS()
static DiagnosedSilenceableFailure matchBlock(Block &block, Operation *op, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block assigning op as the payload of the block's first argu...
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:76
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:134
BlockArgListType getArguments()
Definition: Block.h:84
Operation & front()
Definition: Block.h:150
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:206
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:41
A class for computing basic dominance information.
Definition: Dominance.h:136
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:308
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:263
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & assumeVerified()
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:273
OpPrintingFlags & useLocalScope()
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:281
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:267
This is a value defined by a result of an operation.
Definition: Value.h:453
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
type_range getTypes() const
Definition: ValueRange.cpp:26
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents success/failure for parsing-like operations that find it important to chain tog...
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:232
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:49
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:55
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
Definition: PatternMatch.h:822
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:340
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
size_t size() const
Return the size of this range.
Definition: TypeRange.h:145
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:125
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:208
user_range getUsers() const
Definition: Value.h:224
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:382
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool succeeded() const
Returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:41
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.