MLIR  21.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Verifier.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Pass/PassRegistry.h"
33 #include "mlir/Transforms/CSE.h"
37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
44 #include "llvm/Support/InterleavedRange.h"
45 #include <optional>
46 
47 #define DEBUG_TYPE "transform-dialect"
48 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
49 
50 #define DEBUG_TYPE_MATCHER "transform-matcher"
51 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
52 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
53 
54 using namespace mlir;
55 
56 static ParseResult parseApplyRegisteredPassOptions(
57  OpAsmParser &parser, DictionaryAttr &options,
58  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
60  Operation *op,
61  DictionaryAttr options,
62  ValueRange dynamicOptions);
63 static ParseResult parseSequenceOpOperands(
64  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
65  Type &rootType,
66  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
67  SmallVectorImpl<Type> &extraBindingTypes);
68 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
69  Value root, Type rootType,
70  ValueRange extraBindings,
71  TypeRange extraBindingTypes);
72 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
73  ArrayAttr matchers, ArrayAttr actions);
74 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
75  ArrayAttr &matchers,
76  ArrayAttr &actions);
77 
78 /// Helper function to check if the given transform op is contained in (or
79 /// equal to) the given payload target op. In that case, an error is returned.
80 /// Transforming transform IR that is currently executing is generally unsafe.
82 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
83  Operation *payload) {
84  Operation *transformAncestor = transform.getOperation();
85  while (transformAncestor) {
86  if (transformAncestor == payload) {
88  transform.emitDefiniteFailure()
89  << "cannot apply transform to itself (or one of its ancestors)";
90  diag.attachNote(payload->getLoc()) << "target payload op";
91  return diag;
92  }
93  transformAncestor = transformAncestor->getParentOp();
94  }
96 }
97 
98 #define GET_OP_CLASSES
99 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
100 
101 //===----------------------------------------------------------------------===//
102 // AlternativesOp
103 //===----------------------------------------------------------------------===//
104 
106 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
107  if (!point.isParent() && getOperation()->getNumOperands() == 1)
108  return getOperation()->getOperands();
109  return OperandRange(getOperation()->operand_end(),
110  getOperation()->operand_end());
111 }
112 
113 void transform::AlternativesOp::getSuccessorRegions(
114  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
115  for (Region &alternative : llvm::drop_begin(
116  getAlternatives(),
117  point.isParent() ? 0
118  : point.getRegionOrNull()->getRegionNumber() + 1)) {
119  regions.emplace_back(&alternative, !getOperands().empty()
120  ? alternative.getArguments()
122  }
123  if (!point.isParent())
124  regions.emplace_back(getOperation()->getResults());
125 }
126 
127 void transform::AlternativesOp::getRegionInvocationBounds(
128  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
129  (void)operands;
130  // The region corresponding to the first alternative is always executed, the
131  // remaining may or may not be executed.
132  bounds.reserve(getNumRegions());
133  bounds.emplace_back(1, 1);
134  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
135 }
136 
138  transform::TransformResults &results) {
139  for (const auto &res : block->getParentOp()->getOpResults())
140  results.set(res, {});
141 }
142 
144 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
146  transform::TransformState &state) {
147  SmallVector<Operation *> originals;
148  if (Value scopeHandle = getScope())
149  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
150  else
151  originals.push_back(state.getTopLevel());
152 
153  for (Operation *original : originals) {
154  if (original->isAncestor(getOperation())) {
155  auto diag = emitDefiniteFailure()
156  << "scope must not contain the transforms being applied";
157  diag.attachNote(original->getLoc()) << "scope";
158  return diag;
159  }
160  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
161  auto diag = emitDefiniteFailure()
162  << "only isolated-from-above ops can be alternative scopes";
163  diag.attachNote(original->getLoc()) << "scope";
164  return diag;
165  }
166  }
167 
168  for (Region &reg : getAlternatives()) {
169  // Clone the scope operations and make the transforms in this alternative
170  // region apply to them by virtue of mapping the block argument (the only
171  // visible handle) to the cloned scope operations. This effectively prevents
172  // the transformation from accessing any IR outside the scope.
173  auto scope = state.make_region_scope(reg);
174  auto clones = llvm::to_vector(
175  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
176  auto deleteClones = llvm::make_scope_exit([&] {
177  for (Operation *clone : clones)
178  clone->erase();
179  });
180  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
182 
183  bool failed = false;
184  for (Operation &transform : reg.front().without_terminator()) {
186  state.applyTransform(cast<TransformOpInterface>(transform));
187  if (result.isSilenceableFailure()) {
188  LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
189  << "\n");
190  failed = true;
191  break;
192  }
193 
194  if (::mlir::failed(result.silence()))
196  }
197 
198  // If all operations in the given alternative succeeded, no need to consider
199  // the rest. Replace the original scoping operation with the clone on which
200  // the transformations were performed.
201  if (!failed) {
202  // We will be using the clones, so cancel their scheduled deletion.
203  deleteClones.release();
204  TrackingListener listener(state, *this);
205  IRRewriter rewriter(getContext(), &listener);
206  for (const auto &kvp : llvm::zip(originals, clones)) {
207  Operation *original = std::get<0>(kvp);
208  Operation *clone = std::get<1>(kvp);
209  original->getBlock()->getOperations().insert(original->getIterator(),
210  clone);
211  rewriter.replaceOp(original, clone->getResults());
212  }
213  detail::forwardTerminatorOperands(&reg.front(), state, results);
215  }
216  }
217  return emitSilenceableError() << "all alternatives failed";
218 }
219 
220 void transform::AlternativesOp::getEffects(
221  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
222  consumesHandle(getOperation()->getOpOperands(), effects);
223  producesHandle(getOperation()->getOpResults(), effects);
224  for (Region *region : getRegions()) {
225  if (!region->empty())
226  producesHandle(region->front().getArguments(), effects);
227  }
228  modifiesPayload(effects);
229 }
230 
231 LogicalResult transform::AlternativesOp::verify() {
232  for (Region &alternative : getAlternatives()) {
233  Block &block = alternative.front();
234  Operation *terminator = block.getTerminator();
235  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
236  InFlightDiagnostic diag = emitOpError()
237  << "expects terminator operands to have the "
238  "same type as results of the operation";
239  diag.attachNote(terminator->getLoc()) << "terminator";
240  return diag;
241  }
242  }
243 
244  return success();
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // AnnotateOp
249 //===----------------------------------------------------------------------===//
250 
252 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
254  transform::TransformState &state) {
255  SmallVector<Operation *> targets =
256  llvm::to_vector(state.getPayloadOps(getTarget()));
257 
259  if (auto paramH = getParam()) {
260  ArrayRef<Attribute> params = state.getParams(paramH);
261  if (params.size() != 1) {
262  if (targets.size() != params.size()) {
263  return emitSilenceableError()
264  << "parameter and target have different payload lengths ("
265  << params.size() << " vs " << targets.size() << ")";
266  }
267  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
268  target->setAttr(getName(), attr);
270  }
271  attr = params[0];
272  }
273  for (auto *target : targets)
274  target->setAttr(getName(), attr);
276 }
277 
278 void transform::AnnotateOp::getEffects(
279  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
280  onlyReadsHandle(getTargetMutable(), effects);
281  onlyReadsHandle(getParamMutable(), effects);
282  modifiesPayload(effects);
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // ApplyCommonSubexpressionEliminationOp
287 //===----------------------------------------------------------------------===//
288 
290 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
291  transform::TransformRewriter &rewriter, Operation *target,
292  ApplyToEachResultList &results, transform::TransformState &state) {
293  // Make sure that this transform is not applied to itself. Modifying the
294  // transform IR while it is being interpreted is generally dangerous.
295  DiagnosedSilenceableFailure payloadCheck =
297  if (!payloadCheck.succeeded())
298  return payloadCheck;
299 
300  DominanceInfo domInfo;
301  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
303 }
304 
305 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
306  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
307  transform::onlyReadsHandle(getTargetMutable(), effects);
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // ApplyDeadCodeEliminationOp
313 //===----------------------------------------------------------------------===//
314 
315 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
316  transform::TransformRewriter &rewriter, Operation *target,
317  ApplyToEachResultList &results, transform::TransformState &state) {
318  // Make sure that this transform is not applied to itself. Modifying the
319  // transform IR while it is being interpreted is generally dangerous.
320  DiagnosedSilenceableFailure payloadCheck =
322  if (!payloadCheck.succeeded())
323  return payloadCheck;
324 
325  // Maintain a worklist of potentially dead ops.
326  SetVector<Operation *> worklist;
327 
328  // Helper function that adds all defining ops of used values (operands and
329  // operands of nested ops).
330  auto addDefiningOpsToWorklist = [&](Operation *op) {
331  op->walk([&](Operation *op) {
332  for (Value v : op->getOperands())
333  if (Operation *defOp = v.getDefiningOp())
334  if (target->isProperAncestor(defOp))
335  worklist.insert(defOp);
336  });
337  };
338 
339  // Helper function that erases an op.
340  auto eraseOp = [&](Operation *op) {
341  // Remove op and nested ops from the worklist.
342  op->walk([&](Operation *op) {
343  const auto *it = llvm::find(worklist, op);
344  if (it != worklist.end())
345  worklist.erase(it);
346  });
347  rewriter.eraseOp(op);
348  };
349 
350  // Initial walk over the IR.
351  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
352  if (op != target && isOpTriviallyDead(op)) {
353  addDefiningOpsToWorklist(op);
354  eraseOp(op);
355  }
356  });
357 
358  // Erase all ops that have become dead.
359  while (!worklist.empty()) {
360  Operation *op = worklist.pop_back_val();
361  if (!isOpTriviallyDead(op))
362  continue;
363  addDefiningOpsToWorklist(op);
364  eraseOp(op);
365  }
366 
368 }
369 
370 void transform::ApplyDeadCodeEliminationOp::getEffects(
371  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
372  transform::onlyReadsHandle(getTargetMutable(), effects);
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // ApplyPatternsOp
378 //===----------------------------------------------------------------------===//
379 
380 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
381  transform::TransformRewriter &rewriter, Operation *target,
382  ApplyToEachResultList &results, transform::TransformState &state) {
383  // Make sure that this transform is not applied to itself. Modifying the
384  // transform IR while it is being interpreted is generally dangerous. Even
385  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
386  // performs many additional simplifications such as dead code elimination.
387  DiagnosedSilenceableFailure payloadCheck =
389  if (!payloadCheck.succeeded())
390  return payloadCheck;
391 
392  // Gather all specified patterns.
393  MLIRContext *ctx = target->getContext();
395  if (!getRegion().empty()) {
396  for (Operation &op : getRegion().front()) {
397  cast<transform::PatternDescriptorOpInterface>(&op)
398  .populatePatternsWithState(patterns, state);
399  }
400  }
401 
402  // Configure the GreedyPatternRewriteDriver.
404  config.setListener(
405  static_cast<RewriterBase::Listener *>(rewriter.getListener()));
406  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
407 
408  config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
410  : getMaxIterations());
411  config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
413  : getMaxNumRewrites());
414 
415  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
416  // was requested, apply the greedy pattern rewrite only once. (The greedy
417  // pattern rewrite driver already iterates to a fixpoint internally.)
418  bool cseChanged = false;
419  // One or two iterations should be sufficient. Stop iterating after a certain
420  // threshold to make debugging easier.
421  static const int64_t kNumMaxIterations = 50;
422  int64_t iteration = 0;
423  do {
424  LogicalResult result = failure();
425  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
426  // Op is isolated from above. Apply patterns and also perform region
427  // simplification.
428  result = applyPatternsGreedily(target, frozenPatterns, config);
429  } else {
430  // Manually gather list of ops because the other
431  // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
432  // from above. This way, patterns can be applied to ops that are not
433  // isolated from above. Regions are not being simplified. Furthermore,
434  // only a single greedy rewrite iteration is performed.
436  target->walk([&](Operation *nestedOp) {
437  if (target != nestedOp)
438  ops.push_back(nestedOp);
439  });
440  result = applyOpPatternsGreedily(ops, frozenPatterns, config);
441  }
442 
443  // A failure typically indicates that the pattern application did not
444  // converge.
445  if (failed(result)) {
446  return emitSilenceableFailure(target)
447  << "greedy pattern application failed";
448  }
449 
450  if (getApplyCse()) {
451  DominanceInfo domInfo;
452  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
453  &cseChanged);
454  }
455  } while (cseChanged && ++iteration < kNumMaxIterations);
456 
457  if (iteration == kNumMaxIterations)
458  return emitDefiniteFailure() << "fixpoint iteration did not converge";
459 
461 }
462 
463 LogicalResult transform::ApplyPatternsOp::verify() {
464  if (!getRegion().empty()) {
465  for (Operation &op : getRegion().front()) {
466  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
467  InFlightDiagnostic diag = emitOpError()
468  << "expected children ops to implement "
469  "PatternDescriptorOpInterface";
470  diag.attachNote(op.getLoc()) << "op without interface";
471  return diag;
472  }
473  }
474  }
475  return success();
476 }
477 
478 void transform::ApplyPatternsOp::getEffects(
479  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
480  transform::onlyReadsHandle(getTargetMutable(), effects);
482 }
483 
484 void transform::ApplyPatternsOp::build(
485  OpBuilder &builder, OperationState &result, Value target,
486  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
487  result.addOperands(target);
488 
489  OpBuilder::InsertionGuard g(builder);
490  Region *region = result.addRegion();
491  builder.createBlock(region);
492  if (bodyBuilder)
493  bodyBuilder(builder, result.location);
494 }
495 
496 //===----------------------------------------------------------------------===//
497 // ApplyCanonicalizationPatternsOp
498 //===----------------------------------------------------------------------===//
499 
500 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
502  MLIRContext *ctx = patterns.getContext();
503  for (Dialect *dialect : ctx->getLoadedDialects())
504  dialect->getCanonicalizationPatterns(patterns);
506  op.getCanonicalizationPatterns(patterns, ctx);
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // ApplyConversionPatternsOp
511 //===----------------------------------------------------------------------===//
512 
513 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
516  MLIRContext *ctx = getContext();
517 
518  // Instantiate the default type converter if a type converter builder is
519  // specified.
520  std::unique_ptr<TypeConverter> defaultTypeConverter;
521  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
522  getDefaultTypeConverter();
523  if (typeConverterBuilder)
524  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
525 
526  // Configure conversion target.
527  ConversionTarget conversionTarget(*getContext());
528  if (getLegalOps())
529  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
530  conversionTarget.addLegalOp(
531  OperationName(cast<StringAttr>(attr).getValue(), ctx));
532  if (getIllegalOps())
533  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
534  conversionTarget.addIllegalOp(
535  OperationName(cast<StringAttr>(attr).getValue(), ctx));
536  if (getLegalDialects())
537  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
538  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
539  if (getIllegalDialects())
540  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
541  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
542 
543  // Gather all specified patterns.
545  // Need to keep the converters alive until after pattern application because
546  // the patterns take a reference to an object that would otherwise get out of
547  // scope.
548  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
549  if (!getPatterns().empty()) {
550  for (Operation &op : getPatterns().front()) {
551  auto descriptor =
552  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
553 
554  // Check if this pattern set specifies a type converter.
555  std::unique_ptr<TypeConverter> typeConverter =
556  descriptor.getTypeConverter();
557  TypeConverter *converter = nullptr;
558  if (typeConverter) {
559  keepAliveConverters.emplace_back(std::move(typeConverter));
560  converter = keepAliveConverters.back().get();
561  } else {
562  // No type converter specified: Use the default type converter.
563  if (!defaultTypeConverter) {
564  auto diag = emitDefiniteFailure()
565  << "pattern descriptor does not specify type "
566  "converter and apply_conversion_patterns op has "
567  "no default type converter";
568  diag.attachNote(op.getLoc()) << "pattern descriptor op";
569  return diag;
570  }
571  converter = defaultTypeConverter.get();
572  }
573 
574  // Add descriptor-specific updates to the conversion target, which may
575  // depend on the final type converter. In structural converters, the
576  // legality of types dictates the dynamic legality of an operation.
577  descriptor.populateConversionTargetRules(*converter, conversionTarget);
578 
579  descriptor.populatePatterns(*converter, patterns);
580  }
581  }
582 
583  // Attach a tracking listener if handles should be preserved. We configure the
584  // listener to allow op replacements with different names, as conversion
585  // patterns typically replace ops with replacement ops that have a different
586  // name.
587  TrackingListenerConfig trackingConfig;
588  trackingConfig.requireMatchingReplacementOpName = false;
589  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
590  ConversionConfig conversionConfig;
591  if (getPreserveHandles())
592  conversionConfig.listener = &trackingListener;
593 
594  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
595  for (Operation *target : state.getPayloadOps(getTarget())) {
596  // Make sure that this transform is not applied to itself. Modifying the
597  // transform IR while it is being interpreted is generally dangerous.
598  DiagnosedSilenceableFailure payloadCheck =
600  if (!payloadCheck.succeeded())
601  return payloadCheck;
602 
603  LogicalResult status = failure();
604  if (getPartialConversion()) {
605  status = applyPartialConversion(target, conversionTarget, frozenPatterns,
606  conversionConfig);
607  } else {
608  status = applyFullConversion(target, conversionTarget, frozenPatterns,
609  conversionConfig);
610  }
611 
612  // Check dialect conversion state.
614  if (failed(status)) {
615  diag = emitSilenceableError() << "dialect conversion failed";
616  diag.attachNote(target->getLoc()) << "target op";
617  }
618 
619  // Check tracking listener error state.
620  DiagnosedSilenceableFailure trackingFailure =
621  trackingListener.checkAndResetError();
622  if (!trackingFailure.succeeded()) {
623  if (diag.succeeded()) {
624  // Tracking failure is the only failure.
625  return trackingFailure;
626  } else {
627  diag.attachNote() << "tracking listener also failed: "
628  << trackingFailure.getMessage();
629  (void)trackingFailure.silence();
630  }
631  }
632 
633  if (!diag.succeeded())
634  return diag;
635  }
636 
638 }
639 
641  if (getNumRegions() != 1 && getNumRegions() != 2)
642  return emitOpError() << "expected 1 or 2 regions";
643  if (!getPatterns().empty()) {
644  for (Operation &op : getPatterns().front()) {
645  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
647  emitOpError() << "expected pattern children ops to implement "
648  "ConversionPatternDescriptorOpInterface";
649  diag.attachNote(op.getLoc()) << "op without interface";
650  return diag;
651  }
652  }
653  }
654  if (getNumRegions() == 2) {
655  Region &typeConverterRegion = getRegion(1);
656  if (!llvm::hasSingleElement(typeConverterRegion.front()))
657  return emitOpError()
658  << "expected exactly one op in default type converter region";
659  Operation *maybeTypeConverter = &typeConverterRegion.front().front();
660  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
661  maybeTypeConverter);
662  if (!typeConverterOp) {
663  InFlightDiagnostic diag = emitOpError()
664  << "expected default converter child op to "
665  "implement TypeConverterBuilderOpInterface";
666  diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
667  return diag;
668  }
669  // Check default type converter type.
670  if (!getPatterns().empty()) {
671  for (Operation &op : getPatterns().front()) {
672  auto descriptor =
673  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
674  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
675  return failure();
676  }
677  }
678  }
679  return success();
680 }
681 
682 void transform::ApplyConversionPatternsOp::getEffects(
683  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
684  if (!getPreserveHandles()) {
685  transform::consumesHandle(getTargetMutable(), effects);
686  } else {
687  transform::onlyReadsHandle(getTargetMutable(), effects);
688  }
690 }
691 
692 void transform::ApplyConversionPatternsOp::build(
693  OpBuilder &builder, OperationState &result, Value target,
694  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
695  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
696  result.addOperands(target);
697 
698  {
699  OpBuilder::InsertionGuard g(builder);
700  Region *region1 = result.addRegion();
701  builder.createBlock(region1);
702  if (patternsBodyBuilder)
703  patternsBodyBuilder(builder, result.location);
704  }
705  {
706  OpBuilder::InsertionGuard g(builder);
707  Region *region2 = result.addRegion();
708  builder.createBlock(region2);
709  if (typeConverterBodyBuilder)
710  typeConverterBodyBuilder(builder, result.location);
711  }
712 }
713 
714 //===----------------------------------------------------------------------===//
715 // ApplyToLLVMConversionPatternsOp
716 //===----------------------------------------------------------------------===//
717 
718 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
719  TypeConverter &typeConverter, RewritePatternSet &patterns) {
720  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
721  assert(dialect && "expected that dialect is loaded");
722  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
723  // ConversionTarget is currently ignored because the enclosing
724  // apply_conversion_patterns op sets up its own ConversionTarget.
725  ConversionTarget target(*getContext());
726  iface->populateConvertToLLVMConversionPatterns(
727  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
728 }
729 
730 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
731  transform::TypeConverterBuilderOpInterface builder) {
732  if (builder.getTypeConverterType() != "LLVMTypeConverter")
733  return emitOpError("expected LLVMTypeConverter");
734  return success();
735 }
736 
738  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
739  if (!dialect)
740  return emitOpError("unknown dialect or dialect not loaded: ")
741  << getDialectName();
742  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
743  if (!iface)
744  return emitOpError(
745  "dialect does not implement ConvertToLLVMPatternInterface or "
746  "extension was not loaded: ")
747  << getDialectName();
748  return success();
749 }
750 
751 //===----------------------------------------------------------------------===//
752 // ApplyLoopInvariantCodeMotionOp
753 //===----------------------------------------------------------------------===//
754 
756 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
757  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
759  transform::TransformState &state) {
760  // Currently, LICM does not remove operations, so we don't need tracking.
761  // If this ever changes, add a LICM entry point that takes a rewriter.
762  moveLoopInvariantCode(target);
764 }
765 
766 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
767  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
768  transform::onlyReadsHandle(getTargetMutable(), effects);
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // ApplyRegisteredPassOp
774 //===----------------------------------------------------------------------===//
775 
776 void transform::ApplyRegisteredPassOp::getEffects(
777  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
778  consumesHandle(getTargetMutable(), effects);
779  onlyReadsHandle(getDynamicOptionsMutable(), effects);
780  producesHandle(getOperation()->getOpResults(), effects);
781  modifiesPayload(effects);
782 }
783 
785 transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
787  transform::TransformState &state) {
788  // Obtain a single options-string to pass to the pass(-pipeline) from options
789  // passed in as a dictionary of keys mapping to values which are either
790  // attributes or param-operands pointing to attributes.
791  OperandRange dynamicOptions = getDynamicOptions();
792 
793  std::string options;
794  llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
795 
796  // A helper to convert an option's attribute value into a corresponding
797  // string representation, with the ability to obtain the attr(s) from a param.
798  std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
799  if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
800  // The corresponding value attribute(s) is/are passed in via a param.
801  // Obtain the param-operand via its specified index.
802  size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
803  assert(dynamicOptionIdx < dynamicOptions.size() &&
804  "the number of ParamOperandAttrs in the options DictionaryAttr"
805  "should be the same as the number of options passed as params");
806  ArrayRef<Attribute> attrsAssociatedToParam =
807  state.getParams(dynamicOptions[dynamicOptionIdx]);
808  // Recursive so as to append all attrs associated to the param.
809  llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
810  ",");
811  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
812  // Recursive so as to append all nested attrs of the array.
813  llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
814  } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
815  // Convert to unquoted string.
816  optionsStream << strAttr.getValue().str();
817  } else {
818  // For all other attributes, ask the attr to print itself (without type).
819  valueAttr.print(optionsStream, /*elideType=*/true);
820  }
821  };
822 
823  // Convert the options DictionaryAttr into a single string.
824  llvm::interleave(
825  getOptions(), optionsStream,
826  [&](auto namedAttribute) {
827  optionsStream << namedAttribute.getName().str(); // Append the key.
828  optionsStream << "="; // And the key-value separator.
829  appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
830  },
831  " ");
832  optionsStream.flush();
833 
834  // Get pass or pass pipeline from registry.
835  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
836  if (!info)
837  info = PassInfo::lookup(getPassName());
838  if (!info)
839  return emitDefiniteFailure()
840  << "unknown pass or pass pipeline: " << getPassName();
841 
842  // Create pass manager and add the pass or pass pipeline.
843  PassManager pm(getContext());
844  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
845  emitError(msg);
846  return failure();
847  }))) {
848  return emitDefiniteFailure()
849  << "failed to add pass or pass pipeline to pipeline: "
850  << getPassName();
851  }
852 
853  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
854  for (Operation *target : targets) {
855  // Make sure that this transform is not applied to itself. Modifying the
856  // transform IR while it is being interpreted is generally dangerous. Even
857  // more so when applying passes because they may perform a wide range of IR
858  // modifications.
859  DiagnosedSilenceableFailure payloadCheck =
861  if (!payloadCheck.succeeded())
862  return payloadCheck;
863 
864  // Run the pass or pass pipeline on the current target operation.
865  if (failed(pm.run(target))) {
866  auto diag = emitSilenceableError() << "pass pipeline failed";
867  diag.attachNote(target->getLoc()) << "target op";
868  return diag;
869  }
870  }
871 
872  // The applied pass will have directly modified the payload IR(s).
873  results.set(llvm::cast<OpResult>(getResult()), targets);
875 }
876 
878  OpAsmParser &parser, DictionaryAttr &options,
879  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
  // Overview: parses `key = value` pairs into `options`. Operands found among
  // the values are appended to `dynamicOptions`; in the dictionary they are
  // replaced by a ParamOperandAttr carrying the operand's position in that
  // list.
  // NOTE(review): the line carrying this function's name and return type is
  // not part of this excerpt.
880  // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
881  SmallVector<NamedAttribute> keyValuePairs;
  // Running position of the next operand appended to `dynamicOptions`.
882  size_t dynamicOptionsIdx = 0;
883 
884  // Helper for allowing parsing of option values which can be of the form:
885  // - a normal attribute
886  // - an operand (which would be converted to an attr referring to the operand)
887  // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
  // Declared as std::function so that the lambda can call itself recursively.
888  std::function<ParseResult(Attribute &)> parseValue =
889  [&](Attribute &valueAttr) -> ParseResult {
890  // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
891  if (succeeded(parser.parseOptionalLSquare())) {
  // NOTE(review): `attrs`, used below, has no visible declaration; the line
  // declaring it (892) appears to be missing from this excerpt.
893 
894  // Recursively parse the array's elements, which might be operands.
  // NOTE(review): the first argument of this call (line 896) is not visible.
895  if (parser.parseCommaSeparatedList(
897  [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
898  " in options dictionary") ||
899  parser.parseRSquare())
900  return failure(); // NB: Attempted parse should've output error message.
901 
902  valueAttr = ArrayAttr::get(parser.getContext(), attrs);
903 
904  return success();
905  }
906 
907  // Parse the value, which can be either an attribute or an operand.
908  OptionalParseResult parsedValueAttr =
909  parser.parseOptionalAttribute(valueAttr);
  // No value means no attribute was found at this position, so an operand is
  // parsed instead.
910  if (!parsedValueAttr.has_value()) {
  // NOTE(review): `operand` has no visible declaration; the line declaring it
  // (911) appears to be missing from this excerpt.
912  ParseResult parsedOperand = parser.parseOperand(operand);
913  if (failed(parsedOperand))
914  return failure(); // NB: Attempted parse should've output error message.
915  // To make use of the operand, we need to store it in the options dict.
916  // As SSA-values cannot occur in attributes, what we do instead is store
917  // an attribute in its place that contains the index of the param-operand,
918  // so that an attr-value associated to the param can be resolved later on.
919  dynamicOptions.push_back(operand);
920  auto wrappedIndex = IntegerAttr::get(
921  IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
922  valueAttr =
923  transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
924  } else if (failed(parsedValueAttr.value())) {
925  return failure(); // NB: Attempted parse should have output error message.
926  } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
  // A literally written ParamOperandAttr is rejected: only this parser may
  // create one, as a stand-in for an operand.
927  return parser.emitError(parser.getCurrentLocation())
928  << "the param_operand attribute is a marker reserved for "
929  << "indicating a value will be passed via params and is only used "
930  << "in the generic print format";
931  }
932 
933  return success();
934  };
935 
936  // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
937  // string and `value` looks like either an attribute or an operand-in-an-attr.
938  std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
939  std::string key;
940  Attribute valueAttr;
941 
  // An empty key is rejected the same way as a missing one.
942  if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
943  return parser.emitError(parser.getCurrentLocation())
944  << "expected key to either be an identifier or a string";
945 
946  if (failed(parser.parseEqual()))
947  return parser.emitError(parser.getCurrentLocation())
948  << "expected '=' after key in key-value pair";
949 
  // On failure an extra diagnostic naming the offending key is emitted here.
950  if (failed(parseValue(valueAttr)))
951  return parser.emitError(parser.getCurrentLocation())
952  << "expected a valid attribute or operand as value associated "
953  << "to key '" << key << "'";
954 
955  keyValuePairs.push_back(NamedAttribute(key, valueAttr));
956 
957  return success();
958  };
959 
  // NOTE(review): the beginning of the statement that ends on the next line
  // (lines 960-961) is not visible in this excerpt.
962  " in options dictionary"))
963  return failure(); // NB: Attempted parse should have output error message.
964 
965  if (DictionaryAttr::findDuplicate(
966  keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
967  .has_value())
968  return parser.emitError(parser.getCurrentLocation())
969  << "duplicate keys found in options dictionary";
970 
  // `keyValuePairs` was sorted by the duplicate check above, hence the
  // "sorted" construction.
971  options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
972 
973  return success();
974 }
975 
977  Operation *op,
978  DictionaryAttr options,
979  ValueRange dynamicOptions) {
  // Prints `options` as `{key = value, ...}`. Values that are ParamOperandAttr
  // are printed as the operand of `dynamicOptions` found at the stored index.
  // NOTE(review): the line carrying this function's name and its `printer`
  // parameter is not part of this excerpt.
  // An empty dictionary prints nothing at all (not even the braces).
980  if (options.empty())
981  return;
982 
  // Declared as std::function so that the lambda can recurse into arrays.
983  std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
984  if (auto paramOperandAttr =
985  dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
986  // Resolve index of param-operand to its actual SSA-value and print that.
987  printer.printOperand(
988  dynamicOptions[paramOperandAttr.getIndex().getInt()]);
989  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
990  // This case is so that ArrayAttr-contained operands are pretty-printed.
991  printer << "[";
992  llvm::interleaveComma(arrayAttr, printer, printOptionValue);
993  printer << "]";
994  } else {
  // Any other attribute is printed through the printer's attribute hook.
995  printer.printAttribute(valueAttr);
996  }
997  };
998 
999  printer << "{";
1000  llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
1001  printer << namedAttribute.getName();
1002  printer << " = ";
1003  printOptionValue(namedAttribute.getValue());
1004  });
1005  printer << "}";
1006 }
1007 
1009  // Check that there is a one-to-one correspondence between param operands
1010  // and references to dynamic options in the options dictionary.
  // NOTE(review): the signature line of this function is not part of this
  // excerpt.
1011 
  // Local copy of the operands; entries are nulled out below as they are
  // referenced, so the op's own operand list is left untouched.
1012  auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1013 
1014  // Helper for option values to mark seen operands as having been seen (once).
1015  std::function<LogicalResult(Attribute)> checkOptionValue =
1016  [&](Attribute valueAttr) -> LogicalResult {
1017  if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1018  size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
  // NOTE(review): `dynamicOptionIdx` is a size_t, so `< 0` can never be true;
  // only the upper-bound comparison is effective here.
1019  if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
1020  return emitOpError()
1021  << "dynamic option index " << dynamicOptionIdx
1022  << " is out of bounds for the number of dynamic options: "
1023  << dynamicOptions.size();
  // A null entry means an earlier attribute already referenced this operand.
1024  if (dynamicOptions[dynamicOptionIdx] == nullptr)
1025  return emitOpError() << "dynamic option index " << dynamicOptionIdx
1026  << " is already used in options";
1027  dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1028  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1029  // Recurse into ArrayAttrs as they may contain references to operands.
1030  for (auto eltAttr : arrayAttr)
1031  if (failed(checkOptionValue(eltAttr)))
1032  return failure();
1033  }
1034  return success();
1035  };
1036 
1037  for (NamedAttribute namedAttr : getOptions())
1038  if (failed(checkOptionValue(namedAttr.getValue())))
1039  return failure();
1040 
1041  // All dynamicOptions-params seen in the dict will have been set to null.
  // Any entry that is still non-null is an operand no attribute refers to.
1042  for (Value dynamicOption : dynamicOptions)
1043  if (dynamicOption)
1044  return emitOpError() << "a param operand does not have a corresponding "
1045  << "param_operand attr in the options dict";
1046 
1047  return success();
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // CastOp
1052 //===----------------------------------------------------------------------===//
1053 
1055 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1056  Operation *target, ApplyToEachResultList &results,
1057  transform::TransformState &state) {
  // The target payload op is forwarded unchanged as the result; `rewriter` and
  // `state` are not used by the visible statements.
  // NOTE(review): the return type line (1054) and the return statement (1059)
  // are not part of this excerpt.
1058  results.push_back(target);
1060 }
1061 
1062 void transform::CastOp::getEffects(
1063  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1064  onlyReadsPayload(effects);
1065  onlyReadsHandle(getInputMutable(), effects);
1066  producesHandle(getOperation()->getOpResults(), effects);
1067 }
1068 
1069 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1070  assert(inputs.size() == 1 && "expected one input");
1071  assert(outputs.size() == 1 && "expected one output");
1072  return llvm::all_of(
1073  std::initializer_list<Type>{inputs.front(), outputs.front()},
1074  llvm::IsaPred<transform::TransformHandleTypeInterface>);
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // CollectMatchingOp
1079 //===----------------------------------------------------------------------===//
1080 
1081 /// Applies matcher operations from the given `block` using
1082 /// `blockArgumentMapping` to initialize block arguments. Updates `state`
1083 /// accordingly. If any of the matcher produces a silenceable failure, discards
1084 /// it (printing the content to the debug output stream) and returns failure. If
1085 /// any of the matchers produces a definite failure, reports it and returns
1086 /// failure. If all matchers in the block succeed, populates `mappings` with the
1087 /// payload entities associated with the block terminator operands. Note that
1088 /// `mappings` will be cleared before that.
/// NOTE(review): the visible body hands a failing matcher's diagnostic back to
/// the caller unchanged (`return diag;`), so the "discards it ... and returns
/// failure" wording above may be out of date - confirm against the callers.
/// NOTE(review): parts of the signature and some statements (lines 1089-1090,
/// 1092, 1098, 1106, 1121) are not part of this excerpt.
1091  ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1093  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
1094  assert(block.getParent() && "cannot match using a detached block");
  // Region scope covering the block's parent region, held for the remainder of
  // this function.
1095  auto matchScope = state.make_region_scope(*block.getParent());
1096  if (failed(
1097  state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1099 
  // Apply every non-terminator op of the block in order, stopping at the first
  // one that does not succeed.
1100  for (Operation &match : block.without_terminator()) {
1101  if (!isa<transform::MatchOpInterface>(match)) {
1102  return emitDefiniteFailure(match.getLoc())
1103  << "expected operations in the match part to "
1104  "implement MatchOpInterface";
1105  }
1107  state.applyTransform(cast<transform::TransformOpInterface>(match));
1108  if (diag.succeeded())
1109  continue;
1110 
1111  return diag;
1112  }
1113 
1114  // Remember the values mapped to the terminator operands so we can
1115  // forward them to the action.
1116  ValueRange yieldedValues = block.getTerminator()->getOperands();
1117  // Our contract with the caller is that the mappings will contain only the
1118  // newly mapped values, clear the rest.
1119  mappings.clear();
1120  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1122 }
1123 
1124 /// Returns `true` if both types implement one of the interfaces provided as
1125 /// template parameters.
1126 template <typename... Tys>
1127 static bool implementSameInterface(Type t1, Type t2) {
1128  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1129 }
1130 
1131 /// Returns `true` if both types implement one of the transform dialect
1132 /// interfaces.
/// The interfaces considered are the operation-handle, parameter and
/// value-handle type interfaces; both types must implement the same one.
/// NOTE(review): the signature line (1133) is not part of this excerpt.
1134  return implementSameInterface<transform::TransformHandleTypeInterface,
1135  transform::TransformParamTypeInterface,
1136  transform::TransformValueHandleTypeInterface>(
1137  t1, t2);
1138 }
1139 
1140 //===----------------------------------------------------------------------===//
1141 // CollectMatchingOp
1142 //===----------------------------------------------------------------------===//
1143 
1145 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1146  transform::TransformResults &results,
1147  transform::TransformState &state) {
  // Walks every payload op nested under each root, runs the matcher's entry
  // block on it and, on a match, appends the single payload entity yielded per
  // result to that result's list.
  // NOTE(review): the return type line and a few statements (lines 1144, 1168,
  // 1170, 1204) are not part of this excerpt.
1148  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1149  getOperation(), getMatcher());
  // A matcher without a body (external declaration) cannot be executed.
1150  if (matcher.isExternal()) {
1151  return emitDefiniteFailure()
1152  << "unresolved external symbol " << getMatcher();
1153  }
1154 
  // One list of collected payload entities per op result.
1155  SmallVector<SmallVector<MappedValue>, 2> rawResults;
1156  rawResults.resize(getOperation()->getNumResults());
  // Set only by the "expected 1 payload object" error path inside the walk.
1157  std::optional<DiagnosedSilenceableFailure> maybeFailure;
1158  for (Operation *root : state.getPayloadOps(getRoot())) {
1159  WalkResult walkResult = root->walk([&](Operation *op) {
1160  DEBUG_MATCHER({
1161  DBGS_MATCHER() << "matching ";
1162  op->print(llvm::dbgs(),
1163  OpPrintingFlags().assumeVerified().skipRegions());
1164  llvm::dbgs() << " @" << op << "\n";
1165  });
1166 
1167  // Try matching.
  // The op currently visited is the sole value bound to the matcher's
  // argument.
1169  SmallVector<transform::MappedValue> inputMapping({op});
1171  matcher.getFunctionBody().front(),
1172  ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1173  mappings);
  // NOTE(review): this path interrupts the walk without populating
  // `maybeFailure`, yet the code after the walk dereferences `maybeFailure`
  // whenever the walk was interrupted - confirm this cannot dereference an
  // empty optional.
1174  if (diag.isDefiniteFailure())
1175  return WalkResult::interrupt();
  // A silenceable failure only means "no match here": log it and move on.
1176  if (diag.isSilenceableFailure()) {
1177  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1178  << " failed: " << diag.getMessage());
1179  return WalkResult::advance();
1180  }
1181 
1182  // If succeeded, collect results.
1183  for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
  // Every yielded value must map to exactly one payload entity.
1184  if (mapping.size() != 1) {
1185  maybeFailure.emplace(emitSilenceableError()
1186  << "result #" << i << ", associated with "
1187  << mapping.size()
1188  << " payload objects, expected 1");
1189  return WalkResult::interrupt();
1190  }
1191  rawResults[i].push_back(mapping[0]);
1192  }
1193  return WalkResult::advance();
1194  });
1195  if (walkResult.wasInterrupted())
1196  return std::move(*maybeFailure);
1197  assert(!maybeFailure && "failure set but the walk was not interrupted");
1198 
  // Results are (re)assigned after each root from the lists accumulated so
  // far.
1199  for (auto &&[opResult, rawResult] :
1200  llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1201  results.setMappedValues(opResult, rawResult);
1202  }
1203  }
1205 }
1206 
1207 void transform::CollectMatchingOp::getEffects(
1208  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1209  onlyReadsHandle(getRootMutable(), effects);
1210  producesHandle(getOperation()->getOpResults(), effects);
1211  onlyReadsPayload(effects);
1212 }
1213 
1214 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1215  SymbolTableCollection &symbolTable) {
1216  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1217  symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1218  if (!matcherSymbol ||
1219  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1220  return emitError() << "unresolved matcher symbol " << getMatcher();
1221 
1222  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1223  if (argumentTypes.size() != 1 ||
1224  !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1225  return emitError()
1226  << "expected the matcher to take one operation handle argument";
1227  }
1228  if (!matcherSymbol.getArgAttr(
1229  0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1230  return emitError() << "expected the matcher argument to be marked readonly";
1231  }
1232 
1233  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1234  if (resultTypes.size() != getOperation()->getNumResults()) {
1235  return emitError()
1236  << "expected the matcher to yield as many values as op has results ("
1237  << getOperation()->getNumResults() << "), got "
1238  << resultTypes.size();
1239  }
1240 
1241  for (auto &&[i, matcherType, resultType] :
1242  llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1243  if (implementSameTransformInterface(matcherType, resultType))
1244  continue;
1245 
1246  return emitError()
1247  << "mismatching type interfaces for matcher result and op result #"
1248  << i;
1249  }
1250 
1251  return success();
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // ForeachMatchOp
1256 //===----------------------------------------------------------------------===//
1257 
1258 // This is fine because nothing is actually consumed by this op.
1259 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1260 
1262 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1263  transform::TransformResults &results,
1264  transform::TransformState &state) {
  // Walks the payload under each root; for every visited op the matcher/action
  // pairs are tried in order and the action of the first matcher that succeeds
  // is executed on it.
  // NOTE(review): several lines of this function (1261, 1265, 1289, 1321,
  // 1341, 1361, 1372) are not part of this excerpt, among them the
  // declarations of `matchActionPairs`, `diag` and `result`; the comments
  // below describe the visible statements only.
1266  matchActionPairs;
1267  matchActionPairs.reserve(getMatchers().size());
1268  SymbolTableCollection symbolTable;
  // Resolve every matcher/action symbol pair once, before walking the payload.
1269  for (auto &&[matcher, action] :
1270  llvm::zip_equal(getMatchers(), getActions())) {
1271  auto matcherSymbol =
1272  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1273  getOperation(), cast<SymbolRefAttr>(matcher));
1274  auto actionSymbol =
1275  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1276  getOperation(), cast<SymbolRefAttr>(action));
1277  assert(matcherSymbol && actionSymbol &&
1278  "unresolved symbols not caught by the verifier");
1279 
  // Symbols without a body (external declarations) cannot be executed.
1280  if (matcherSymbol.isExternal())
1281  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1282  if (actionSymbol.isExternal())
1283  return emitDefiniteFailure() << "unresolved external symbol " << action;
1284 
1285  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1286  }
1287 
  // Accumulates silenceable failures of actions; returned at the very end.
1288  DiagnosedSilenceableFailure overallDiag =
1290 
1291  SmallVector<SmallVector<MappedValue>> matchInputMapping;
1292  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1293  SmallVector<SmallVector<MappedValue>> actionResultMapping;
1294  // Explicitly add the mapping for the first block argument (the op being
1295  // matched).
1296  matchInputMapping.emplace_back();
1297  transform::detail::prepareValueMappings(matchInputMapping,
1298  getForwardedInputs(), state);
  // Slot 0 is rewritten for every visited op inside the walk below.
1299  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1300  actionResultMapping.resize(getForwardedOutputs().size());
1301 
1302  for (Operation *root : state.getPayloadOps(getRoot())) {
1303  WalkResult walkResult = root->walk([&](Operation *op) {
1304  // If getRestrictRoot is not present, skip over the root op itself so we
1305  // don't invalidate it.
1306  if (!getRestrictRoot() && op == root)
1307  return WalkResult::advance();
1308 
1309  DEBUG_MATCHER({
1310  DBGS_MATCHER() << "matching ";
1311  op->print(llvm::dbgs(),
1312  OpPrintingFlags().assumeVerified().skipRegions());
1313  llvm::dbgs() << " @" << op << "\n";
1314  });
1315 
1316  firstMatchArgument.clear();
1317  firstMatchArgument.push_back(op);
1318 
1319  // Try all the match/action pairs until the first successful match.
1320  for (auto [matcher, action] : matchActionPairs) {
1322  matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1323  state, matchOutputMapping);
1324  if (diag.isDefiniteFailure())
1325  return WalkResult::interrupt();
  // A silenceable failure means this matcher did not match: try the next pair.
1326  if (diag.isSilenceableFailure()) {
1327  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1328  << " failed: " << diag.getMessage());
1329  continue;
1330  }
1331 
  // The matcher succeeded: bind its yielded values to the action's arguments.
1332  auto scope = state.make_region_scope(action.getFunctionBody());
1333  if (failed(state.mapBlockArguments(
1334  action.getFunctionBody().front().getArguments(),
1335  matchOutputMapping))) {
1336  return WalkResult::interrupt();
1337  }
1338 
1339  for (Operation &transform :
1340  action.getFunctionBody().front().without_terminator()) {
1342  state.applyTransform(cast<TransformOpInterface>(transform));
1343  if (result.isDefiniteFailure())
1344  return WalkResult::interrupt();
  // A silenceable failure of an action op is recorded as notes on
  // `overallDiag` and silenced; the loop then proceeds with the next op of the
  // action.
1345  if (result.isSilenceableFailure()) {
1346  if (overallDiag.succeeded()) {
1347  overallDiag = emitSilenceableError() << "actions failed";
1348  }
1349  overallDiag.attachNote(action->getLoc())
1350  << "failed action: " << result.getMessage();
1351  overallDiag.attachNote(op->getLoc())
1352  << "when applied to this matching payload";
1353  (void)result.silence();
1354  continue;
1355  }
1356  }
  // Append what the action's terminator yields to the per-result lists.
1357  if (failed(detail::appendValueMappings(
1358  MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1359  action.getFunctionBody().front().getTerminator()->getOperands(),
1360  state, getFlattenResults()))) {
1362  << "action @" << action.getName()
1363  << " has results associated with multiple payload entities, "
1364  "but flattening was not requested";
1365  return WalkResult::interrupt();
1366  }
  // First successful match wins: the remaining pairs are not tried on this op.
1367  break;
1368  }
1369  return WalkResult::advance();
1370  });
1371  if (walkResult.wasInterrupted())
1373  }
1374 
1375  // The root operation should not have been affected, so we can just reassign
1376  // the payload to the result. Note that we need to consume the root handle to
1377  // make sure any handles to operations inside, that could have been affected
1378  // by actions, are invalidated.
1379  results.set(llvm::cast<OpResult>(getUpdated()),
1380  state.getPayloadOps(getRoot()));
1381  for (auto &&[result, mapping] :
1382  llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1383  results.setMappedValues(result, mapping);
1384  }
1385  return overallDiag;
1386 }
1387 
1388 void transform::ForeachMatchOp::getAsmResultNames(
1389  OpAsmSetValueNameFn setNameFn) {
1390  setNameFn(getUpdated(), "updated_root");
1391  for (Value v : getForwardedOutputs()) {
1392  setNameFn(v, "yielded");
1393  }
1394 }
1395 
1396 void transform::ForeachMatchOp::getEffects(
1397  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1398  // Bail if invalid.
1399  if (getOperation()->getNumOperands() < 1 ||
1400  getOperation()->getNumResults() < 1) {
1401  return modifiesPayload(effects);
1402  }
1403 
1404  consumesHandle(getRootMutable(), effects);
1405  onlyReadsHandle(getForwardedInputsMutable(), effects);
1406  producesHandle(getOperation()->getOpResults(), effects);
1407  modifiesPayload(effects);
1408 }
1409 
1410 /// Parses the comma-separated list of symbol reference pairs of the format
1411 /// `@matcher -> @action`.
1412 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1413  ArrayAttr &matchers,
1414  ArrayAttr &actions) {
1415  StringAttr matcher;
1416  StringAttr action;
1417  SmallVector<Attribute> matcherList;
1418  SmallVector<Attribute> actionList;
1419  do {
1420  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1421  parser.parseSymbolName(action)) {
1422  return failure();
1423  }
1424  matcherList.push_back(SymbolRefAttr::get(matcher));
1425  actionList.push_back(SymbolRefAttr::get(action));
1426  } while (parser.parseOptionalComma().succeeded());
1427 
1428  matchers = parser.getBuilder().getArrayAttr(matcherList);
1429  actions = parser.getBuilder().getArrayAttr(actionList);
1430  return success();
1431 }
1432 
1433 /// Prints the comma-separated list of symbol reference pairs of the format
1434 /// `@matcher -> @action`.
/// Each pair is printed on its own line, with the indentation raised by two
/// levels for the duration of the list.
/// NOTE(review): the first line of the signature (1435), which declares
/// `printer`, is not part of this excerpt.
1436  ArrayAttr matchers, ArrayAttr actions) {
1437  printer.increaseIndent();
1438  printer.increaseIndent();
  // The index sequence is zipped in only to recognize the last pair.
1439  for (auto &&[matcher, action, idx] : llvm::zip_equal(
1440  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1441  printer.printNewline();
1442  printer << cast<SymbolRefAttr>(matcher) << " -> "
1443  << cast<SymbolRefAttr>(action);
  // No separator after the last pair.
1444  if (idx != matchers.size() - 1)
1445  printer << ", ";
1446  }
1447  printer.decreaseIndent();
1448  printer.decreaseIndent();
1449 }
1450 
1451 LogicalResult transform::ForeachMatchOp::verify() {
1452  if (getMatchers().size() != getActions().size())
1453  return emitOpError() << "expected the same number of matchers and actions";
1454  if (getMatchers().empty())
1455  return emitOpError() << "expected at least one match/action pair";
1456 
1457  llvm::SmallPtrSet<Attribute, 8> matcherNames;
1458  for (Attribute name : getMatchers()) {
1459  if (matcherNames.insert(name).second)
1460  continue;
1461  emitWarning() << "matcher " << name
1462  << " is used more than once, only the first match will apply";
1463  }
1464 
1465  return success();
1466 }
1467 
1468 /// Checks that the attributes of the function-like operation have correct
1469 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
1470 /// annotations being present even if they can be inferred from the body.
/// If `emitWarnings`, additionally warns about arguments that are marked as
/// consumed although the body does not consume them.
/// NOTE(review): the return type line (1471) and the final return statement
/// (1512) are not part of this excerpt.
1472 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1473  bool alsoVerifyInternal = false) {
1474  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
  // Indices of the entry block arguments consumed in the body; left empty for
  // external (body-less) functions.
1475  llvm::SmallDenseSet<unsigned> consumedArguments;
1476  if (!op.isExternal()) {
1477  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1478  consumedArguments);
1479  }
1480  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1481  bool isConsumed =
1482  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1483  nullptr;
1484  bool isReadOnly =
1485  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1486  nullptr;
  // The two annotations are mutually exclusive.
1487  if (isConsumed && isReadOnly) {
1488  return transformOp.emitSilenceableError()
1489  << "argument #" << i << " cannot be both readonly and consumed";
1490  }
  // External functions, and internal ones when requested, must carry an
  // explicit annotation.
1491  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1492  return transformOp.emitSilenceableError()
1493  << "must provide consumed/readonly status for arguments of "
1494  "external or called ops";
1495  }
  // The remaining checks compare annotations with the body, which external
  // functions do not have.
1496  if (op.isExternal())
1497  continue;
1498 
  // Fires only for arguments explicitly marked readonly; arguments consumed in
  // the body but carrying no annotation at all pass this check.
1499  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1500  return transformOp.emitSilenceableError()
1501  << "argument #" << i
1502  << " is consumed in the body but is not marked as such";
1503  }
1504  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1505  // Cannot use op.emitWarning() here as it would attempt to verify the op
1506  // before printing, resulting in infinite recursion.
1507  emitWarning(op->getLoc())
1508  << "op argument #" << i
1509  << " is not consumed in the body but is marked as consumed";
1510  }
1511  }
1513 }
1514 
// Verifies, for every matcher/action pair, that both symbols resolve to
// function-like transform ops with valid consumed/readonly annotations, and
// that counts and type interfaces agree along the chain
// op operands -> matcher arguments, matcher results -> action arguments,
// action results -> op results (excluding the first op result).
// NOTE(review): the lines declaring `diag` (1553, 1563, 1573, 1605, 1618) are
// not part of this excerpt.
1515 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1516  SymbolTableCollection &symbolTable) {
1517  assert(getMatchers().size() == getActions().size());
1518  auto consumedAttr =
1519  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1520  for (auto &&[matcher, action] :
1521  llvm::zip_equal(getMatchers(), getActions())) {
1522  // Presence and typing.
1523  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1524  symbolTable.lookupNearestSymbolFrom(getOperation(),
1525  cast<SymbolRefAttr>(matcher)));
1526  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1527  symbolTable.lookupNearestSymbolFrom(getOperation(),
1528  cast<SymbolRefAttr>(action)));
1529  if (!matcherSymbol ||
1530  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1531  return emitError() << "unresolved matcher symbol " << matcher;
1532  if (!actionSymbol ||
1533  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1534  return emitError() << "unresolved action symbol " << action;
1535 
  // Annotations are required on both symbols (alsoVerifyInternal), without the
  // optional warnings.
1536  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1537  /*emitWarnings=*/false,
1538  /*alsoVerifyInternal=*/true)
1539  .checkAndReport())) {
1540  return failure();
1541  }
1542  if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
1543  /*emitWarnings=*/false,
1544  /*alsoVerifyInternal=*/true)
1545  .checkAndReport())) {
1546  return failure();
1547  }
1548 
1549  // Input -> matcher forwarding.
1550  TypeRange operandTypes = getOperandTypes();
1551  TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1552  if (operandTypes.size() != matcherArguments.size()) {
1554  emitError() << "the number of operands (" << operandTypes.size()
1555  << ") doesn't match the number of matcher arguments ("
1556  << matcherArguments.size() << ") for " << matcher;
1557  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1558  return diag;
1559  }
1560  for (auto &&[i, operand, argument] :
1561  llvm::enumerate(operandTypes, matcherArguments)) {
  // A matcher must not consume any of its arguments.
1562  if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1564  emitOpError()
1565  << "does not expect matcher symbol to consume its operand #" << i;
1566  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1567  return diag;
1568  }
1569 
1570  if (implementSameTransformInterface(operand, argument))
1571  continue;
1572 
1574  emitError()
1575  << "mismatching type interfaces for operand and matcher argument #"
1576  << i << " of matcher " << matcher;
1577  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1578  return diag;
1579  }
1580 
1581  // Matcher -> action forwarding.
1582  TypeRange matcherResults = matcherSymbol.getResultTypes();
1583  TypeRange actionArguments = actionSymbol.getArgumentTypes();
1584  if (matcherResults.size() != actionArguments.size()) {
1585  return emitError() << "mismatching number of matcher results and "
1586  "action arguments between "
1587  << matcher << " (" << matcherResults.size() << ") and "
1588  << action << " (" << actionArguments.size() << ")";
1589  }
1590  for (auto &&[i, matcherType, actionType] :
1591  llvm::enumerate(matcherResults, actionArguments)) {
1592  if (implementSameTransformInterface(matcherType, actionType))
1593  continue;
1594 
  // NOTE(review): "of matcher " below has no leading space, so the index and
  // the word "of" run together in the emitted message (e.g. "#0of matcher").
1595  return emitError() << "mismatching type interfaces for matcher result "
1596  "and action argument #"
1597  << i << "of matcher " << matcher << " and action "
1598  << action;
1599  }
1600 
1601  // Action -> result forwarding.
1602  TypeRange actionResults = actionSymbol.getResultTypes();
  // The first op result is excluded from the comparison with action results.
1603  auto resultTypes = TypeRange(getResultTypes()).drop_front();
1604  if (actionResults.size() != resultTypes.size()) {
1606  emitError() << "the number of action results ("
1607  << actionResults.size() << ") for " << action
1608  << " doesn't match the number of extra op results ("
1609  << resultTypes.size() << ")";
1610  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1611  return diag;
1612  }
1613  for (auto &&[i, resultType, actionType] :
1614  llvm::enumerate(resultTypes, actionResults)) {
1615  if (implementSameTransformInterface(resultType, actionType))
1616  continue;
1617 
1619  emitError() << "mismatching type interfaces for action result #" << i
1620  << " of action " << action << " and op result";
1621  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1622  return diag;
1623  }
1624  }
1625  return success();
1626 }
1627 
1628 //===----------------------------------------------------------------------===//
1629 // ForeachOp
1630 //===----------------------------------------------------------------------===//
1631 
1633 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1634  transform::TransformResults &results,
1635  transform::TransformState &state) {
1636  // We store the payloads before executing the body as ops may be removed from
1637  // the mapping by the TrackingRewriter while iteration is in progress.
1639  detail::prepareValueMappings(payloads, getTargets(), state);
1640  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1641  bool withZipShortest = getWithZipShortest();
1642 
1643  // In case of `zip_shortest`, set the number of iterations to the
1644  // smallest payload in the targets.
1645  if (withZipShortest) {
1646  numIterations =
1647  llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1648  const SmallVector<MappedValue> &B) {
1649  return A.size() < B.size();
1650  })->size();
1651 
1652  for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1653  payloads[argIdx].resize(numIterations);
1654  }
1655 
1656  // As we will be "zipping" over them, check all payloads have the same size.
1657  // `zip_shortest` adjusts all payloads to the same size, so skip this check
1658  // when true.
1659  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1660  argIdx++) {
1661  if (payloads[argIdx].size() != numIterations) {
1662  return emitSilenceableError()
1663  << "prior targets' payload size (" << numIterations
1664  << ") differs from payload size (" << payloads[argIdx].size()
1665  << ") of target " << getTargets()[argIdx];
1666  }
1667  }
1668 
1669  // Start iterating, indexing into payloads to obtain the right arguments to
1670  // call the body with - each slice of payloads at the same argument index
1671  // corresponding to a tuple to use as the body's block arguments.
1672  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1673  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1674  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1675  auto scope = state.make_region_scope(getBody());
1676  // Set up arguments to the region's block.
1677  for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1678  MappedValue argument = payloads[argIdx][iterIdx];
1679  // Note that each blockArg's handle gets associated with just a single
1680  // element from the corresponding target's payload.
1681  if (failed(state.mapBlockArgument(blockArg, {argument})))
1683  }
1684 
1685  // Execute loop body.
1686  for (Operation &transform : getBody().front().without_terminator()) {
1687  DiagnosedSilenceableFailure result = state.applyTransform(
1688  llvm::cast<transform::TransformOpInterface>(transform));
1689  if (!result.succeeded())
1690  return result;
1691  }
1692 
1693  // Append yielded payloads to corresponding results from prior iterations.
1694  OperandRange yieldOperands = getYieldOp().getOperands();
1695  for (auto &&[result, yieldOperand, resTuple] :
1696  llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1697  // NB: each iteration we add any number of ops/vals/params to a result.
1698  if (isa<TransformHandleTypeInterface>(result.getType()))
1699  llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1700  else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1701  llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1702  else if (isa<TransformParamTypeInterface>(result.getType()))
1703  llvm::append_range(resTuple, state.getParams(yieldOperand));
1704  else
1705  assert(false && "unhandled handle type");
1706  }
1707 
1708  // Associate the accumulated result payloads to the op's actual results.
1709  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1710  results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1711 
1713 }
1714 
1715 void transform::ForeachOp::getEffects(
1716  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1717  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1718  // arity errors, this method might get called before/in absence of `verify()`.
1719  for (auto &&[target, blockArg] :
1720  llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1721  BlockArgument blockArgument = blockArg;
1722  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1723  return isHandleConsumed(blockArgument,
1724  cast<TransformOpInterface>(&op));
1725  })) {
1726  consumesHandle(target, effects);
1727  } else {
1728  onlyReadsHandle(target, effects);
1729  }
1730  }
1731 
1732  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1733  return doesModifyPayload(cast<TransformOpInterface>(&op));
1734  })) {
1735  modifiesPayload(effects);
1736  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1737  return doesReadPayload(cast<TransformOpInterface>(&op));
1738  })) {
1739  onlyReadsPayload(effects);
1740  }
1741 
1742  producesHandle(getOperation()->getOpResults(), effects);
1743 }
1744 
1745 void transform::ForeachOp::getSuccessorRegions(
1746  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1747  Region *bodyRegion = &getBody();
1748  if (point.isParent()) {
1749  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1750  return;
1751  }
1752 
1753  // Branch back to the region or the parent.
1754  assert(point == getBody() && "unexpected region index");
1755  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1756  regions.emplace_back();
1757 }
1758 
1760 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1761  // Each block argument handle is mapped to a subset (one op to be precise)
1762  // of the payload of the corresponding `targets` operand of ForeachOp.
1763  assert(point == getBody() && "unexpected region index");
1764  return getOperation()->getOperands();
1765 }
1766 
1767 transform::YieldOp transform::ForeachOp::getYieldOp() {
1768  return cast<transform::YieldOp>(getBody().front().getTerminator());
1769 }
1770 
1771 LogicalResult transform::ForeachOp::verify() {
1772  for (auto [targetOpt, bodyArgOpt] :
1773  llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1774  if (!targetOpt || !bodyArgOpt)
1775  return emitOpError() << "expects the same number of targets as the body "
1776  "has block arguments";
1777  if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1778  return emitOpError(
1779  "expects co-indexed targets and the body's "
1780  "block arguments to have the same op/value/param type");
1781  }
1782 
1783  for (auto [resultOpt, yieldOperandOpt] :
1784  llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1785  if (!resultOpt || !yieldOperandOpt)
1786  return emitOpError() << "expects the same number of results as the "
1787  "yield terminator has operands";
1788  if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1789  return emitOpError("expects co-indexed results and yield "
1790  "operands to have the same op/value/param type");
1791  }
1792 
1793  return success();
1794 }
1795 
1796 //===----------------------------------------------------------------------===//
1797 // GetParentOp
1798 //===----------------------------------------------------------------------===//
1799 
1801 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1802  transform::TransformResults &results,
1803  transform::TransformState &state) {
1804  SmallVector<Operation *> parents;
1805  DenseSet<Operation *> resultSet;
1806  for (Operation *target : state.getPayloadOps(getTarget())) {
1807  Operation *parent = target;
1808  for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1809  parent = parent->getParentOp();
1810  while (parent) {
1811  bool checkIsolatedFromAbove =
1812  !getIsolatedFromAbove() ||
1814  bool checkOpName = !getOpName().has_value() ||
1815  parent->getName().getStringRef() == *getOpName();
1816  if (checkIsolatedFromAbove && checkOpName)
1817  break;
1818  parent = parent->getParentOp();
1819  }
1820  if (!parent) {
1821  if (getAllowEmptyResults()) {
1822  results.set(llvm::cast<OpResult>(getResult()), parents);
1824  }
1826  emitSilenceableError()
1827  << "could not find a parent op that matches all requirements";
1828  diag.attachNote(target->getLoc()) << "target op";
1829  return diag;
1830  }
1831  }
1832  if (getDeduplicate()) {
1833  if (resultSet.insert(parent).second)
1834  parents.push_back(parent);
1835  } else {
1836  parents.push_back(parent);
1837  }
1838  }
1839  results.set(llvm::cast<OpResult>(getResult()), parents);
1841 }
1842 
1843 //===----------------------------------------------------------------------===//
1844 // GetConsumersOfResult
1845 //===----------------------------------------------------------------------===//
1846 
1848 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1849  transform::TransformResults &results,
1850  transform::TransformState &state) {
1851  int64_t resultNumber = getResultNumber();
1852  auto payloadOps = state.getPayloadOps(getTarget());
1853  if (std::empty(payloadOps)) {
1854  results.set(cast<OpResult>(getResult()), {});
1856  }
1857  if (!llvm::hasSingleElement(payloadOps))
1858  return emitDefiniteFailure()
1859  << "handle must be mapped to exactly one payload op";
1860 
1861  Operation *target = *payloadOps.begin();
1862  if (target->getNumResults() <= resultNumber)
1863  return emitDefiniteFailure() << "result number overflow";
1864  results.set(llvm::cast<OpResult>(getResult()),
1865  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1867 }
1868 
1869 //===----------------------------------------------------------------------===//
1870 // GetDefiningOp
1871 //===----------------------------------------------------------------------===//
1872 
1874 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1875  transform::TransformResults &results,
1876  transform::TransformState &state) {
1877  SmallVector<Operation *> definingOps;
1878  for (Value v : state.getPayloadValues(getTarget())) {
1879  if (llvm::isa<BlockArgument>(v)) {
1881  emitSilenceableError() << "cannot get defining op of block argument";
1882  diag.attachNote(v.getLoc()) << "target value";
1883  return diag;
1884  }
1885  definingOps.push_back(v.getDefiningOp());
1886  }
1887  results.set(llvm::cast<OpResult>(getResult()), definingOps);
1889 }
1890 
1891 //===----------------------------------------------------------------------===//
1892 // GetProducerOfOperand
1893 //===----------------------------------------------------------------------===//
1894 
1896 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1897  transform::TransformResults &results,
1898  transform::TransformState &state) {
1899  int64_t operandNumber = getOperandNumber();
1900  SmallVector<Operation *> producers;
1901  for (Operation *target : state.getPayloadOps(getTarget())) {
1902  Operation *producer =
1903  target->getNumOperands() <= operandNumber
1904  ? nullptr
1905  : target->getOperand(operandNumber).getDefiningOp();
1906  if (!producer) {
1908  emitSilenceableError()
1909  << "could not find a producer for operand number: " << operandNumber
1910  << " of " << *target;
1911  diag.attachNote(target->getLoc()) << "target op";
1912  return diag;
1913  }
1914  producers.push_back(producer);
1915  }
1916  results.set(llvm::cast<OpResult>(getResult()), producers);
1918 }
1919 
1920 //===----------------------------------------------------------------------===//
1921 // GetOperandOp
1922 //===----------------------------------------------------------------------===//
1923 
1925 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1926  transform::TransformResults &results,
1927  transform::TransformState &state) {
1928  SmallVector<Value> operands;
1929  for (Operation *target : state.getPayloadOps(getTarget())) {
1930  SmallVector<int64_t> operandPositions;
1932  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1933  target->getNumOperands(), operandPositions);
1934  if (diag.isSilenceableFailure()) {
1935  diag.attachNote(target->getLoc())
1936  << "while considering positions of this payload operation";
1937  return diag;
1938  }
1939  llvm::append_range(operands,
1940  llvm::map_range(operandPositions, [&](int64_t pos) {
1941  return target->getOperand(pos);
1942  }));
1943  }
1944  results.setValues(cast<OpResult>(getResult()), operands);
1946 }
1947 
1948 LogicalResult transform::GetOperandOp::verify() {
1949  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1950  getIsInverted(), getIsAll());
1951 }
1952 
1953 //===----------------------------------------------------------------------===//
1954 // GetResultOp
1955 //===----------------------------------------------------------------------===//
1956 
1958 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1959  transform::TransformResults &results,
1960  transform::TransformState &state) {
1961  SmallVector<Value> opResults;
1962  for (Operation *target : state.getPayloadOps(getTarget())) {
1963  SmallVector<int64_t> resultPositions;
1965  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1966  target->getNumResults(), resultPositions);
1967  if (diag.isSilenceableFailure()) {
1968  diag.attachNote(target->getLoc())
1969  << "while considering positions of this payload operation";
1970  return diag;
1971  }
1972  llvm::append_range(opResults,
1973  llvm::map_range(resultPositions, [&](int64_t pos) {
1974  return target->getResult(pos);
1975  }));
1976  }
1977  results.setValues(cast<OpResult>(getResult()), opResults);
1979 }
1980 
1981 LogicalResult transform::GetResultOp::verify() {
1982  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1983  getIsInverted(), getIsAll());
1984 }
1985 
1986 //===----------------------------------------------------------------------===//
1987 // GetTypeOp
1988 //===----------------------------------------------------------------------===//
1989 
1990 void transform::GetTypeOp::getEffects(
1991  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1992  onlyReadsHandle(getValueMutable(), effects);
1993  producesHandle(getOperation()->getOpResults(), effects);
1994  onlyReadsPayload(effects);
1995 }
1996 
1998 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1999  transform::TransformResults &results,
2000  transform::TransformState &state) {
2001  SmallVector<Attribute> params;
2002  for (Value value : state.getPayloadValues(getValue())) {
2003  Type type = value.getType();
2004  if (getElemental()) {
2005  if (auto shaped = dyn_cast<ShapedType>(type)) {
2006  type = shaped.getElementType();
2007  }
2008  }
2009  params.push_back(TypeAttr::get(type));
2010  }
2011  results.setParams(cast<OpResult>(getResult()), params);
2013 }
2014 
2015 //===----------------------------------------------------------------------===//
2016 // IncludeOp
2017 //===----------------------------------------------------------------------===//
2018 
2019 /// Applies the transform ops contained in `block`. Maps `results` to the same
2020 /// values as the operands of the block terminator.
2022 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2024  transform::TransformResults &results) {
2025  // Apply the sequenced ops one by one.
2026  for (Operation &transform : block.without_terminator()) {
2028  state.applyTransform(cast<transform::TransformOpInterface>(transform));
2029  if (result.isDefiniteFailure())
2030  return result;
2031 
2032  if (result.isSilenceableFailure()) {
2033  if (mode == transform::FailurePropagationMode::Propagate) {
2034  // Propagate empty results in case of early exit.
2035  forwardEmptyOperands(&block, state, results);
2036  return result;
2037  }
2038  (void)result.silence();
2039  }
2040  }
2041 
2042  // Forward the operation mapping for values yielded from the sequence to the
2043  // values produced by the sequence op.
2044  transform::detail::forwardTerminatorOperands(&block, state, results);
2046 }
2047 
2049 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2050  transform::TransformResults &results,
2051  transform::TransformState &state) {
2052  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2053  getOperation(), getTarget());
2054  assert(callee && "unverified reference to unknown symbol");
2055 
2056  if (callee.isExternal())
2057  return emitDefiniteFailure() << "unresolved external named sequence";
2058 
2059  // Map operands to block arguments.
2061  detail::prepareValueMappings(mappings, getOperands(), state);
2062  auto scope = state.make_region_scope(callee.getBody());
2063  for (auto &&[arg, map] :
2064  llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2065  if (failed(state.mapBlockArgument(arg, map)))
2067  }
2068 
2070  callee.getBody().front(), getFailurePropagationMode(), state, results);
2071  mappings.clear();
2073  mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2074  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2075  results.setMappedValues(result, mapping);
2076  return result;
2077 }
2078 
2080 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2081 
2082 void transform::IncludeOp::getEffects(
2083  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2084  // Always mark as modifying the payload.
2085  // TODO: a mechanism to annotate effects on payload. Even when all handles are
2086  // only read, the payload may still be modified, so we currently stay on the
2087  // conservative side and always indicate modification. This may prevent some
2088  // code reordering.
2089  modifiesPayload(effects);
2090 
2091  // Results are always produced.
2092  producesHandle(getOperation()->getOpResults(), effects);
2093 
2094  // Adds default effects to operands and results. This will be added if
2095  // preconditions fail so the trait verifier doesn't complain about missing
2096  // effects and the real precondition failure is reported later on.
2097  auto defaultEffects = [&] {
2098  onlyReadsHandle(getOperation()->getOpOperands(), effects);
2099  };
2100 
2101  // Bail if the callee is unknown. This may run as part of the verification
2102  // process before we verified the validity of the callee or of this op.
2103  auto target =
2104  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2105  if (!target)
2106  return defaultEffects();
2107  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2108  getOperation(), getTarget());
2109  if (!callee)
2110  return defaultEffects();
2111  DiagnosedSilenceableFailure earlyVerifierResult =
2112  verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
2113  if (!earlyVerifierResult.succeeded()) {
2114  (void)earlyVerifierResult.silence();
2115  return defaultEffects();
2116  }
2117 
2118  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2119  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2120  consumesHandle(getOperation()->getOpOperand(i), effects);
2121  else
2122  onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2123  }
2124 }
2125 
2126 LogicalResult
2127 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2128  // Access through indirection and do additional checking because this may be
2129  // running before the main op verifier.
2130  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2131  if (!targetAttr)
2132  return emitOpError() << "expects a 'target' symbol reference attribute";
2133 
2134  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2135  *this, targetAttr);
2136  if (!target)
2137  return emitOpError() << "does not reference a named transform sequence";
2138 
2139  FunctionType fnType = target.getFunctionType();
2140  if (fnType.getNumInputs() != getNumOperands())
2141  return emitError("incorrect number of operands for callee");
2142 
2143  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2144  if (getOperand(i).getType() != fnType.getInput(i)) {
2145  return emitOpError("operand type mismatch: expected operand type ")
2146  << fnType.getInput(i) << ", but provided "
2147  << getOperand(i).getType() << " for operand number " << i;
2148  }
2149  }
2150 
2151  if (fnType.getNumResults() != getNumResults())
2152  return emitError("incorrect number of results for callee");
2153 
2154  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2155  Type resultType = getResult(i).getType();
2156  Type funcType = fnType.getResult(i);
2157  if (!implementSameTransformInterface(resultType, funcType)) {
2158  return emitOpError() << "type of result #" << i
2159  << " must implement the same transform dialect "
2160  "interface as the corresponding callee result";
2161  }
2162  }
2163 
2165  cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2166  /*alsoVerifyInternal=*/true)
2167  .checkAndReport();
2168 }
2169 
2170 //===----------------------------------------------------------------------===//
2171 // MatchOperationEmptyOp
2172 //===----------------------------------------------------------------------===//
2173 
2174 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2175  ::std::optional<::mlir::Operation *> maybeCurrent,
2177  if (!maybeCurrent.has_value()) {
2178  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
2180  }
2181  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
2182  return emitSilenceableError() << "operation is not empty";
2183 }
2184 
2185 //===----------------------------------------------------------------------===//
2186 // MatchOperationNameOp
2187 //===----------------------------------------------------------------------===//
2188 
2189 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2190  Operation *current, transform::TransformResults &results,
2191  transform::TransformState &state) {
2192  StringRef currentOpName = current->getName().getStringRef();
2193  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2194  if (acceptedAttr.getValue() == currentOpName)
2196  }
2197  return emitSilenceableError() << "wrong operation name";
2198 }
2199 
2200 //===----------------------------------------------------------------------===//
2201 // MatchParamCmpIOp
2202 //===----------------------------------------------------------------------===//
2203 
2205 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
2206  transform::TransformResults &results,
2207  transform::TransformState &state) {
2208  auto signedAPIntAsString = [&](const APInt &value) {
2209  std::string str;
2210  llvm::raw_string_ostream os(str);
2211  value.print(os, /*isSigned=*/true);
2212  return str;
2213  };
2214 
2215  ArrayRef<Attribute> params = state.getParams(getParam());
2216  ArrayRef<Attribute> references = state.getParams(getReference());
2217 
2218  if (params.size() != references.size()) {
2219  return emitSilenceableError()
2220  << "parameters have different payload lengths (" << params.size()
2221  << " vs " << references.size() << ")";
2222  }
2223 
2224  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2225  auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2226  auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2227  if (!intAttr || !refAttr) {
2228  return emitDefiniteFailure()
2229  << "non-integer parameter value not expected";
2230  }
2231  if (intAttr.getType() != refAttr.getType()) {
2232  return emitDefiniteFailure()
2233  << "mismatching integer attribute types in parameter #" << i;
2234  }
2235  APInt value = intAttr.getValue();
2236  APInt refValue = refAttr.getValue();
2237 
2238  // TODO: this copy will not be necessary in C++20.
2239  int64_t position = i;
2240  auto reportError = [&](StringRef direction) {
2242  emitSilenceableError() << "expected parameter to be " << direction
2243  << " " << signedAPIntAsString(refValue)
2244  << ", got " << signedAPIntAsString(value);
2245  diag.attachNote(getParam().getLoc())
2246  << "value # " << position
2247  << " associated with the parameter defined here";
2248  return diag;
2249  };
2250 
2251  switch (getPredicate()) {
2252  case MatchCmpIPredicate::eq:
2253  if (value.eq(refValue))
2254  break;
2255  return reportError("equal to");
2256  case MatchCmpIPredicate::ne:
2257  if (value.ne(refValue))
2258  break;
2259  return reportError("not equal to");
2260  case MatchCmpIPredicate::lt:
2261  if (value.slt(refValue))
2262  break;
2263  return reportError("less than");
2264  case MatchCmpIPredicate::le:
2265  if (value.sle(refValue))
2266  break;
2267  return reportError("less than or equal to");
2268  case MatchCmpIPredicate::gt:
2269  if (value.sgt(refValue))
2270  break;
2271  return reportError("greater than");
2272  case MatchCmpIPredicate::ge:
2273  if (value.sge(refValue))
2274  break;
2275  return reportError("greater than or equal to");
2276  }
2277  }
2279 }
2280 
2281 void transform::MatchParamCmpIOp::getEffects(
2282  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2283  onlyReadsHandle(getParamMutable(), effects);
2284  onlyReadsHandle(getReferenceMutable(), effects);
2285 }
2286 
2287 //===----------------------------------------------------------------------===//
2288 // ParamConstantOp
2289 //===----------------------------------------------------------------------===//
2290 
2292 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2293  transform::TransformResults &results,
2294  transform::TransformState &state) {
2295  results.setParams(cast<OpResult>(getParam()), {getValue()});
2297 }
2298 
2299 //===----------------------------------------------------------------------===//
2300 // MergeHandlesOp
2301 //===----------------------------------------------------------------------===//
2302 
2304 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2305  transform::TransformResults &results,
2306  transform::TransformState &state) {
2307  ValueRange handles = getHandles();
2308  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2309  SmallVector<Operation *> operations;
2310  for (Value operand : handles)
2311  llvm::append_range(operations, state.getPayloadOps(operand));
2312  if (!getDeduplicate()) {
2313  results.set(llvm::cast<OpResult>(getResult()), operations);
2315  }
2316 
2317  SetVector<Operation *> uniqued(llvm::from_range, operations);
2318  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2320  }
2321 
2322  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2323  SmallVector<Attribute> attrs;
2324  for (Value attribute : handles)
2325  llvm::append_range(attrs, state.getParams(attribute));
2326  if (!getDeduplicate()) {
2327  results.setParams(cast<OpResult>(getResult()), attrs);
2329  }
2330 
2331  SetVector<Attribute> uniqued(llvm::from_range, attrs);
2332  results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2334  }
2335 
2336  assert(
2337  llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2338  "expected value handle type");
2339  SmallVector<Value> payloadValues;
2340  for (Value value : handles)
2341  llvm::append_range(payloadValues, state.getPayloadValues(value));
2342  if (!getDeduplicate()) {
2343  results.setValues(cast<OpResult>(getResult()), payloadValues);
2345  }
2346 
2347  SetVector<Value> uniqued(llvm::from_range, payloadValues);
2348  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2350 }
2351 
2352 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2353  // Handles may be the same if deduplicating is enabled.
2354  return getDeduplicate();
2355 }
2356 
2357 void transform::MergeHandlesOp::getEffects(
2358  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2359  onlyReadsHandle(getHandlesMutable(), effects);
2360  producesHandle(getOperation()->getOpResults(), effects);
2361 
2362  // There are no effects on the Payload IR as this is only a handle
2363  // manipulation.
2364 }
2365 
2366 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2367  if (getDeduplicate() || getHandles().size() != 1)
2368  return {};
2369 
2370  // If deduplication is not required and there is only one operand, it can be
2371  // used directly instead of merging.
2372  return getHandles().front();
2373 }
2374 
2375 //===----------------------------------------------------------------------===//
2376 // NamedSequenceOp
2377 //===----------------------------------------------------------------------===//
2378 
2380 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2381  transform::TransformResults &results,
2382  transform::TransformState &state) {
2383  if (isExternal())
2384  return emitDefiniteFailure() << "unresolved external named sequence";
2385 
2386  // Map the entry block argument to the list of operations.
2387  // Note: this is the same implementation as PossibleTopLevelTransformOp but
2388  // without attaching the interface / trait since that is tailored to a
2389  // dangling top-level op that does not get "called".
2390  auto scope = state.make_region_scope(getBody());
2392  state, this->getOperation(), getBody())))
2394 
2395  return applySequenceBlock(getBody().front(),
2396  FailurePropagationMode::Propagate, state, results);
2397 }
2398 
2399 void transform::NamedSequenceOp::getEffects(
2400  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2401 
2403  OperationState &result) {
2405  parser, result, /*allowVariadic=*/false,
2406  getFunctionTypeAttrName(result.name),
2407  [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2409  std::string &) { return builder.getFunctionType(inputs, results); },
2410  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2411 }
2412 
2415  printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2416  getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2417  getResAttrsAttrName());
2418 }
2419 
2420 /// Verifies that a symbol function-like transform dialect operation has the
2421 /// signature and the terminator that have conforming types, i.e., types
2422 /// implementing the same transform dialect type interface. If `allowExternal`
2423 /// is set, allow external symbols (declarations) and don't check the terminator
2424 /// as it may not exist.
2426 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2427  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2430  << "cannot be defined inside another transform op";
2431  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2432  return diag;
2433  }
2434 
2435  if (op.isExternal() || op.getFunctionBody().empty()) {
2436  if (allowExternal)
2438 
2439  return emitSilenceableFailure(op) << "cannot be external";
2440  }
2441 
2442  if (op.getFunctionBody().front().empty())
2443  return emitSilenceableFailure(op) << "expected a non-empty body block";
2444 
2445  Operation *terminator = &op.getFunctionBody().front().back();
2446  if (!isa<transform::YieldOp>(terminator)) {
2448  << "expected '"
2449  << transform::YieldOp::getOperationName()
2450  << "' as terminator";
2451  diag.attachNote(terminator->getLoc()) << "terminator";
2452  return diag;
2453  }
2454 
2455  if (terminator->getNumOperands() != op.getResultTypes().size()) {
2456  return emitSilenceableFailure(terminator)
2457  << "expected terminator to have as many operands as the parent op "
2458  "has results";
2459  }
2460  for (auto [i, operandType, resultType] : llvm::zip_equal(
2461  llvm::seq<unsigned>(0, terminator->getNumOperands()),
2462  terminator->getOperands().getType(), op.getResultTypes())) {
2463  if (operandType == resultType)
2464  continue;
2465  return emitSilenceableFailure(terminator)
2466  << "the type of the terminator operand #" << i
2467  << " must match the type of the corresponding parent op result ("
2468  << operandType << " vs " << resultType << ")";
2469  }
2470 
2472 }
2473 
2474 /// Verification of a NamedSequenceOp. This does not report the error
2475 /// immediately, so it can be used to check for op's well-formedness before the
2476 /// verifier runs, e.g., during trait verification.
2478 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2479  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2480  if (!parent->getAttr(
2481  transform::TransformDialect::kWithNamedSequenceAttrName)) {
2484  << "expects the parent symbol table to have the '"
2485  << transform::TransformDialect::kWithNamedSequenceAttrName
2486  << "' attribute";
2487  diag.attachNote(parent->getLoc()) << "symbol table operation";
2488  return diag;
2489  }
2490  }
2491 
2492  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2495  << "cannot be defined inside another transform op";
2496  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2497  return diag;
2498  }
2499 
2500  if (op.isExternal() || op.getBody().empty())
2501  return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2502  emitWarnings);
2503 
2504  if (op.getBody().front().empty())
2505  return emitSilenceableFailure(op) << "expected a non-empty body block";
2506 
2507  Operation *terminator = &op.getBody().front().back();
2508  if (!isa<transform::YieldOp>(terminator)) {
2510  << "expected '"
2511  << transform::YieldOp::getOperationName()
2512  << "' as terminator";
2513  diag.attachNote(terminator->getLoc()) << "terminator";
2514  return diag;
2515  }
2516 
2517  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2518  return emitSilenceableFailure(terminator)
2519  << "expected terminator to have as many operands as the parent op "
2520  "has results";
2521  }
2522  for (auto [i, operandType, resultType] :
2523  llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2524  terminator->getOperands().getType(),
2525  op.getFunctionType().getResults())) {
2526  if (operandType == resultType)
2527  continue;
2528  return emitSilenceableFailure(terminator)
2529  << "the type of the terminator operand #" << i
2530  << " must match the type of the corresponding parent op result ("
2531  << operandType << " vs " << resultType << ")";
2532  }
2533 
2534  auto funcOp = cast<FunctionOpInterface>(*op);
2536  verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2537  if (!diag.succeeded())
2538  return diag;
2539 
2540  return verifyYieldingSingleBlockOp(funcOp,
2541  /*allowExternal=*/true);
2542 }
2543 
2544 LogicalResult transform::NamedSequenceOp::verify() {
2545  // Actual verification happens in a separate function for reusability.
2546  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2547 }
2548 
2549 template <typename FnTy>
2550 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2551  Type bbArgType, TypeRange extraBindingTypes,
2552  FnTy bodyBuilder) {
2553  SmallVector<Type> types;
2554  types.reserve(1 + extraBindingTypes.size());
2555  types.push_back(bbArgType);
2556  llvm::append_range(types, extraBindingTypes);
2557 
2558  OpBuilder::InsertionGuard guard(builder);
2559  Region *region = state.regions.back().get();
2560  Block *bodyBlock =
2561  builder.createBlock(region, region->begin(), types,
2562  SmallVector<Location>(types.size(), state.location));
2563 
2564  // Populate body.
2565  builder.setInsertionPointToStart(bodyBlock);
2566  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2567  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2568  } else {
2569  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2570  bodyBlock->getArguments().drop_front());
2571  }
2572 }
2573 
2574 void transform::NamedSequenceOp::build(OpBuilder &builder,
2575  OperationState &state, StringRef symName,
2576  Type rootType, TypeRange resultTypes,
2577  SequenceBodyBuilderFn bodyBuilder,
2579  ArrayRef<DictionaryAttr> argAttrs) {
2580  state.addAttribute(SymbolTable::getSymbolAttrName(),
2581  builder.getStringAttr(symName));
2582  state.addAttribute(getFunctionTypeAttrName(state.name),
2584  rootType, resultTypes)));
2585  state.attributes.append(attrs.begin(), attrs.end());
2586  state.addRegion();
2587 
2588  buildSequenceBody(builder, state, rootType,
2589  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2590 }
2591 
2592 //===----------------------------------------------------------------------===//
2593 // NumAssociationsOp
2594 //===----------------------------------------------------------------------===//
2595 
2597 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2598  transform::TransformResults &results,
2599  transform::TransformState &state) {
2600  size_t numAssociations =
2602  .Case([&](TransformHandleTypeInterface opHandle) {
2603  return llvm::range_size(state.getPayloadOps(getHandle()));
2604  })
2605  .Case([&](TransformValueHandleTypeInterface valueHandle) {
2606  return llvm::range_size(state.getPayloadValues(getHandle()));
2607  })
2608  .Case([&](TransformParamTypeInterface param) {
2609  return llvm::range_size(state.getParams(getHandle()));
2610  })
2611  .Default([](Type) {
2612  llvm_unreachable("unknown kind of transform dialect type");
2613  return 0;
2614  });
2615  results.setParams(cast<OpResult>(getNum()),
2616  rewriter.getI64IntegerAttr(numAssociations));
2618 }
2619 
2620 LogicalResult transform::NumAssociationsOp::verify() {
2621  // Verify that the result type accepts an i64 attribute as payload.
2622  auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2623  return resultType
2624  .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2625  .checkAndReport();
2626 }
2627 
2628 //===----------------------------------------------------------------------===//
2629 // SelectOp
2630 //===----------------------------------------------------------------------===//
2631 
2633 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2634  transform::TransformResults &results,
2635  transform::TransformState &state) {
2636  SmallVector<Operation *> result;
2637  auto payloadOps = state.getPayloadOps(getTarget());
2638  for (Operation *op : payloadOps) {
2639  if (op->getName().getStringRef() == getOpName())
2640  result.push_back(op);
2641  }
2642  results.set(cast<OpResult>(getResult()), result);
2644 }
2645 
2646 //===----------------------------------------------------------------------===//
2647 // SplitHandleOp
2648 //===----------------------------------------------------------------------===//
2649 
2650 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2651  Value target, int64_t numResultHandles) {
2652  result.addOperands(target);
2653  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2654 }
2655 
2657 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2658  transform::TransformResults &results,
2659  transform::TransformState &state) {
2660  int64_t numPayloads =
2662  .Case<TransformHandleTypeInterface>([&](auto x) {
2663  return llvm::range_size(state.getPayloadOps(getHandle()));
2664  })
2665  .Case<TransformValueHandleTypeInterface>([&](auto x) {
2666  return llvm::range_size(state.getPayloadValues(getHandle()));
2667  })
2668  .Case<TransformParamTypeInterface>([&](auto x) {
2669  return llvm::range_size(state.getParams(getHandle()));
2670  })
2671  .Default([](auto x) {
2672  llvm_unreachable("unknown transform dialect type interface");
2673  return -1;
2674  });
2675 
2676  auto produceNumOpsError = [&]() {
2677  return emitSilenceableError()
2678  << getHandle() << " expected to contain " << this->getNumResults()
2679  << " payloads but it contains " << numPayloads << " payloads";
2680  };
2681 
2682  // Fail if there are more payload ops than results and no overflow result was
2683  // specified.
2684  if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2685  return produceNumOpsError();
2686 
2687  // Fail if there are more results than payload ops. Unless:
2688  // - "fail_on_payload_too_small" is set to "false", or
2689  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2690  if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2691  (numPayloads != 0 || !getPassThroughEmptyHandle()))
2692  return produceNumOpsError();
2693 
2694  // Distribute payloads.
2695  SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2696  if (getOverflowResult())
2697  resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2698 
2699  auto container = [&]() {
2700  if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2701  return llvm::map_to_vector(
2702  state.getPayloadOps(getHandle()),
2703  [](Operation *op) -> MappedValue { return op; });
2704  }
2705  if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2706  return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2707  [](Value v) -> MappedValue { return v; });
2708  }
2709  assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2710  "unsupported kind of transform dialect type");
2711  return llvm::map_to_vector(state.getParams(getHandle()),
2712  [](Attribute a) -> MappedValue { return a; });
2713  }();
2714 
2715  for (auto &&en : llvm::enumerate(container)) {
2716  int64_t resultNum = en.index();
2717  if (resultNum >= getNumResults())
2718  resultNum = *getOverflowResult();
2719  resultHandles[resultNum].push_back(en.value());
2720  }
2721 
2722  // Set transform op results.
2723  for (auto &&it : llvm::enumerate(resultHandles))
2724  results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2725  it.value());
2726 
2728 }
2729 
2730 void transform::SplitHandleOp::getEffects(
2731  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2732  onlyReadsHandle(getHandleMutable(), effects);
2733  producesHandle(getOperation()->getOpResults(), effects);
2734  // There are no effects on the Payload IR as this is only a handle
2735  // manipulation.
2736 }
2737 
2738 LogicalResult transform::SplitHandleOp::verify() {
2739  if (getOverflowResult().has_value() &&
2740  !(*getOverflowResult() < getNumResults()))
2741  return emitOpError("overflow_result is not a valid result index");
2742 
2743  for (Type resultType : getResultTypes()) {
2744  if (implementSameTransformInterface(getHandle().getType(), resultType))
2745  continue;
2746 
2747  return emitOpError("expects result types to implement the same transform "
2748  "interface as the operand type");
2749  }
2750 
2751  return success();
2752 }
2753 
2754 //===----------------------------------------------------------------------===//
2755 // ReplicateOp
2756 //===----------------------------------------------------------------------===//
2757 
2759 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2760  transform::TransformResults &results,
2761  transform::TransformState &state) {
2762  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2763  for (const auto &en : llvm::enumerate(getHandles())) {
2764  Value handle = en.value();
2765  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2766  SmallVector<Operation *> current =
2767  llvm::to_vector(state.getPayloadOps(handle));
2768  SmallVector<Operation *> payload;
2769  payload.reserve(numRepetitions * current.size());
2770  for (unsigned i = 0; i < numRepetitions; ++i)
2771  llvm::append_range(payload, current);
2772  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2773  } else {
2774  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2775  "expected param type");
2776  ArrayRef<Attribute> current = state.getParams(handle);
2777  SmallVector<Attribute> params;
2778  params.reserve(numRepetitions * current.size());
2779  for (unsigned i = 0; i < numRepetitions; ++i)
2780  llvm::append_range(params, current);
2781  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2782  params);
2783  }
2784  }
2786 }
2787 
2788 void transform::ReplicateOp::getEffects(
2789  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2790  onlyReadsHandle(getPatternMutable(), effects);
2791  onlyReadsHandle(getHandlesMutable(), effects);
2792  producesHandle(getOperation()->getOpResults(), effects);
2793 }
2794 
2795 //===----------------------------------------------------------------------===//
2796 // SequenceOp
2797 //===----------------------------------------------------------------------===//
2798 
2800 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2801  transform::TransformResults &results,
2802  transform::TransformState &state) {
2803  // Map the entry block argument to the list of operations.
2804  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2805  if (failed(mapBlockArguments(state)))
2807 
2808  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2809  results);
2810 }
2811 
2812 static ParseResult parseSequenceOpOperands(
2813  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2814  Type &rootType,
2815  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2816  SmallVectorImpl<Type> &extraBindingTypes) {
2817  OpAsmParser::UnresolvedOperand rootOperand;
2818  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2819  if (!hasRoot.has_value()) {
2820  root = std::nullopt;
2821  return success();
2822  }
2823  if (failed(hasRoot.value()))
2824  return failure();
2825  root = rootOperand;
2826 
2827  if (succeeded(parser.parseOptionalComma())) {
2828  if (failed(parser.parseOperandList(extraBindings)))
2829  return failure();
2830  }
2831  if (failed(parser.parseColon()))
2832  return failure();
2833 
2834  // The paren is truly optional.
2835  (void)parser.parseOptionalLParen();
2836 
2837  if (failed(parser.parseType(rootType))) {
2838  return failure();
2839  }
2840 
2841  if (!extraBindings.empty()) {
2842  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2843  return failure();
2844  }
2845 
2846  if (extraBindingTypes.size() != extraBindings.size()) {
2847  return parser.emitError(parser.getNameLoc(),
2848  "expected types to be provided for all operands");
2849  }
2850 
2851  // The paren is truly optional.
2852  (void)parser.parseOptionalRParen();
2853  return success();
2854 }
2855 
2857  Value root, Type rootType,
2858  ValueRange extraBindings,
2859  TypeRange extraBindingTypes) {
2860  if (!root)
2861  return;
2862 
2863  printer << root;
2864  bool hasExtras = !extraBindings.empty();
2865  if (hasExtras) {
2866  printer << ", ";
2867  printer.printOperands(extraBindings);
2868  }
2869 
2870  printer << " : ";
2871  if (hasExtras)
2872  printer << "(";
2873 
2874  printer << rootType;
2875  if (hasExtras)
2876  printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2877 }
2878 
2879 /// Returns `true` if the given op operand may be consuming the handle value in
2880 /// the Transform IR. That is, if it may have a Free effect on it.
2882  // Conservatively assume the effect being present in absence of the interface.
2883  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2884  if (!iface)
2885  return true;
2886 
2887  return isHandleConsumed(use.get(), iface);
2888 }
2889 
2890 LogicalResult
2892  function_ref<InFlightDiagnostic()> reportError) {
2893  OpOperand *potentialConsumer = nullptr;
2894  for (OpOperand &use : value.getUses()) {
2895  if (!isValueUsePotentialConsumer(use))
2896  continue;
2897 
2898  if (!potentialConsumer) {
2899  potentialConsumer = &use;
2900  continue;
2901  }
2902 
2903  InFlightDiagnostic diag = reportError()
2904  << " has more than one potential consumer";
2905  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2906  << "used here as operand #" << potentialConsumer->getOperandNumber();
2907  diag.attachNote(use.getOwner()->getLoc())
2908  << "used here as operand #" << use.getOperandNumber();
2909  return diag;
2910  }
2911 
2912  return success();
2913 }
2914 
2915 LogicalResult transform::SequenceOp::verify() {
2916  assert(getBodyBlock()->getNumArguments() >= 1 &&
2917  "the number of arguments must have been verified to be more than 1 by "
2918  "PossibleTopLevelTransformOpTrait");
2919 
2920  if (!getRoot() && !getExtraBindings().empty()) {
2921  return emitOpError()
2922  << "does not expect extra operands when used as top-level";
2923  }
2924 
2925  // Check if a block argument has more than one consuming use.
2926  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2927  if (failed(checkDoubleConsume(arg, [this, arg]() {
2928  return (emitOpError() << "block argument #" << arg.getArgNumber());
2929  }))) {
2930  return failure();
2931  }
2932  }
2933 
2934  // Check properties of the nested operations they cannot check themselves.
2935  for (Operation &child : *getBodyBlock()) {
2936  if (!isa<TransformOpInterface>(child) &&
2937  &child != &getBodyBlock()->back()) {
2939  emitOpError()
2940  << "expected children ops to implement TransformOpInterface";
2941  diag.attachNote(child.getLoc()) << "op without interface";
2942  return diag;
2943  }
2944 
2945  for (OpResult result : child.getResults()) {
2946  auto report = [&]() {
2947  return (child.emitError() << "result #" << result.getResultNumber());
2948  };
2949  if (failed(checkDoubleConsume(result, report)))
2950  return failure();
2951  }
2952  }
2953 
2954  if (!getBodyBlock()->mightHaveTerminator())
2955  return emitOpError() << "expects to have a terminator in the body";
2956 
2957  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2958  getOperation()->getResultTypes()) {
2959  InFlightDiagnostic diag = emitOpError()
2960  << "expects the types of the terminator operands "
2961  "to match the types of the result";
2962  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2963  return diag;
2964  }
2965  return success();
2966 }
2967 
2968 void transform::SequenceOp::getEffects(
2969  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2970  getPotentialTopLevelEffects(effects);
2971 }
2972 
2974 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2975  assert(point == getBody() && "unexpected region index");
2976  if (getOperation()->getNumOperands() > 0)
2977  return getOperation()->getOperands();
2978  return OperandRange(getOperation()->operand_end(),
2979  getOperation()->operand_end());
2980 }
2981 
2982 void transform::SequenceOp::getSuccessorRegions(
2983  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2984  if (point.isParent()) {
2985  Region *bodyRegion = &getBody();
2986  regions.emplace_back(bodyRegion, getNumOperands() != 0
2987  ? bodyRegion->getArguments()
2989  return;
2990  }
2991 
2992  assert(point == getBody() && "unexpected region index");
2993  regions.emplace_back(getOperation()->getResults());
2994 }
2995 
2996 void transform::SequenceOp::getRegionInvocationBounds(
2997  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2998  (void)operands;
2999  bounds.emplace_back(1, 1);
3000 }
3001 
3002 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3003  TypeRange resultTypes,
3004  FailurePropagationMode failurePropagationMode,
3005  Value root,
3006  SequenceBodyBuilderFn bodyBuilder) {
3007  build(builder, state, resultTypes, failurePropagationMode, root,
3008  /*extra_bindings=*/ValueRange());
3009  Type bbArgType = root.getType();
3010  buildSequenceBody(builder, state, bbArgType,
3011  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3012 }
3013 
3014 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3015  TypeRange resultTypes,
3016  FailurePropagationMode failurePropagationMode,
3017  Value root, ValueRange extraBindings,
3018  SequenceBodyBuilderArgsFn bodyBuilder) {
3019  build(builder, state, resultTypes, failurePropagationMode, root,
3020  extraBindings);
3021  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3022  bodyBuilder);
3023 }
3024 
3025 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3026  TypeRange resultTypes,
3027  FailurePropagationMode failurePropagationMode,
3028  Type bbArgType,
3029  SequenceBodyBuilderFn bodyBuilder) {
3030  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3031  /*extra_bindings=*/ValueRange());
3032  buildSequenceBody(builder, state, bbArgType,
3033  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3034 }
3035 
3036 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3037  TypeRange resultTypes,
3038  FailurePropagationMode failurePropagationMode,
3039  Type bbArgType, TypeRange extraBindingTypes,
3040  SequenceBodyBuilderArgsFn bodyBuilder) {
3041  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3042  /*extra_bindings=*/ValueRange());
3043  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3044 }
3045 
3046 //===----------------------------------------------------------------------===//
3047 // PrintOp
3048 //===----------------------------------------------------------------------===//
3049 
3050 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3051  StringRef name) {
3052  if (!name.empty())
3053  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3054 }
3055 
3056 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3057  Value target, StringRef name) {
3058  result.addOperands({target});
3059  build(builder, result, name);
3060 }
3061 
3063 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3064  transform::TransformResults &results,
3065  transform::TransformState &state) {
3066  llvm::outs() << "[[[ IR printer: ";
3067  if (getName().has_value())
3068  llvm::outs() << *getName() << " ";
3069 
3070  OpPrintingFlags printFlags;
3071  if (getAssumeVerified().value_or(false))
3072  printFlags.assumeVerified();
3073  if (getUseLocalScope().value_or(false))
3074  printFlags.useLocalScope();
3075  if (getSkipRegions().value_or(false))
3076  printFlags.skipRegions();
3077 
3078  if (!getTarget()) {
3079  llvm::outs() << "top-level ]]]\n";
3080  state.getTopLevel()->print(llvm::outs(), printFlags);
3081  llvm::outs() << "\n";
3082  llvm::outs().flush();
3084  }
3085 
3086  llvm::outs() << "]]]\n";
3087  for (Operation *target : state.getPayloadOps(getTarget())) {
3088  target->print(llvm::outs(), printFlags);
3089  llvm::outs() << "\n";
3090  }
3091 
3092  llvm::outs().flush();
3094 }
3095 
3096 void transform::PrintOp::getEffects(
3097  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3098  // We don't really care about mutability here, but `getTarget` now
3099  // unconditionally casts to a specific type before verification could run
3100  // here.
3101  if (!getTargetMutable().empty())
3102  onlyReadsHandle(getTargetMutable()[0], effects);
3103  onlyReadsPayload(effects);
3104 
3105  // There is no resource for stderr file descriptor, so just declare print
3106  // writes into the default resource.
3107  effects.emplace_back(MemoryEffects::Write::get());
3108 }
3109 
3110 //===----------------------------------------------------------------------===//
3111 // VerifyOp
3112 //===----------------------------------------------------------------------===//
3113 
3115 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3116  Operation *target,
3118  transform::TransformState &state) {
3119  if (failed(::mlir::verify(target))) {
3121  << "failed to verify payload op";
3122  diag.attachNote(target->getLoc()) << "payload op";
3123  return diag;
3124  }
3126 }
3127 
3128 void transform::VerifyOp::getEffects(
3129  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3130  transform::onlyReadsHandle(getTargetMutable(), effects);
3131 }
3132 
3133 //===----------------------------------------------------------------------===//
3134 // YieldOp
3135 //===----------------------------------------------------------------------===//
3136 
3137 void transform::YieldOp::getEffects(
3138  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3139  onlyReadsHandle(getOperandsMutable(), effects);
3140 }
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a #dlti.dl_entry attribute.
Definition: DLTI.cpp:40
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue >> blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DBGS_MATCHER()
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
#define DEBUG_MATCHER(x)
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
#define DBGS()
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:78
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:110
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:729
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:428
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:297
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:289
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:283
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getType() const
Definition: ValueRange.cpp:32
type_range getTypes() const
Definition: ValueRange.cpp:28
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:719
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator over the underlying Values.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:232
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:58
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
static void printOptionValue(raw_ostream &os, const bool &value)
Utility methods for printing option values.
Definition: PassOptions.h:60
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:382
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.