MLIR  22.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/Verifier.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Pass/PassRegistry.h"
30 #include "mlir/Transforms/CSE.h"
34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/SmallPtrSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/DebugLog.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include <optional>
44 
45 #define DEBUG_TYPE "transform-dialect"
46 #define DEBUG_TYPE_MATCHER "transform-matcher"
47 
48 using namespace mlir;
49 
50 static ParseResult parseApplyRegisteredPassOptions(
51  OpAsmParser &parser, DictionaryAttr &options,
52  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
54  Operation *op,
55  DictionaryAttr options,
56  ValueRange dynamicOptions);
57 static ParseResult parseSequenceOpOperands(
58  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
59  Type &rootType,
60  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
61  SmallVectorImpl<Type> &extraBindingTypes);
62 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
63  Value root, Type rootType,
64  ValueRange extraBindings,
65  TypeRange extraBindingTypes);
66 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
67  ArrayAttr matchers, ArrayAttr actions);
68 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
69  ArrayAttr &matchers,
70  ArrayAttr &actions);
71 
72 /// Helper function to check if the given transform op is contained in (or
73 /// equal to) the given payload target op. In that case, an error is returned.
74 /// Transforming transform IR that is currently executing is generally unsafe.
76 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
77  Operation *payload) {
78  Operation *transformAncestor = transform.getOperation();
79  while (transformAncestor) {
80  if (transformAncestor == payload) {
82  transform.emitDefiniteFailure()
83  << "cannot apply transform to itself (or one of its ancestors)";
84  diag.attachNote(payload->getLoc()) << "target payload op";
85  return diag;
86  }
87  transformAncestor = transformAncestor->getParentOp();
88  }
90 }
91 
92 #define GET_OP_CLASSES
93 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
94 
95 //===----------------------------------------------------------------------===//
96 // AlternativesOp
97 //===----------------------------------------------------------------------===//
98 
100 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
101  if (!point.isParent() && getOperation()->getNumOperands() == 1)
102  return getOperation()->getOperands();
103  return OperandRange(getOperation()->operand_end(),
104  getOperation()->operand_end());
105 }
106 
107 void transform::AlternativesOp::getSuccessorRegions(
108  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
109  for (Region &alternative : llvm::drop_begin(
110  getAlternatives(),
111  point.isParent() ? 0
112  : point.getRegionOrNull()->getRegionNumber() + 1)) {
113  regions.emplace_back(&alternative, !getOperands().empty()
114  ? alternative.getArguments()
116  }
117  if (!point.isParent())
118  regions.emplace_back(getOperation()->getResults());
119 }
120 
121 void transform::AlternativesOp::getRegionInvocationBounds(
122  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
123  (void)operands;
124  // The region corresponding to the first alternative is always executed, the
125  // remaining may or may not be executed.
126  bounds.reserve(getNumRegions());
127  bounds.emplace_back(1, 1);
128  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
129 }
130 
132  transform::TransformResults &results) {
133  for (const auto &res : block->getParentOp()->getOpResults())
134  results.set(res, {});
135 }
136 
138 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
140  transform::TransformState &state) {
141  SmallVector<Operation *> originals;
142  if (Value scopeHandle = getScope())
143  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
144  else
145  originals.push_back(state.getTopLevel());
146 
147  for (Operation *original : originals) {
148  if (original->isAncestor(getOperation())) {
149  auto diag = emitDefiniteFailure()
150  << "scope must not contain the transforms being applied";
151  diag.attachNote(original->getLoc()) << "scope";
152  return diag;
153  }
154  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
155  auto diag = emitDefiniteFailure()
156  << "only isolated-from-above ops can be alternative scopes";
157  diag.attachNote(original->getLoc()) << "scope";
158  return diag;
159  }
160  }
161 
162  for (Region &reg : getAlternatives()) {
163  // Clone the scope operations and make the transforms in this alternative
164  // region apply to them by virtue of mapping the block argument (the only
165  // visible handle) to the cloned scope operations. This effectively prevents
166  // the transformation from accessing any IR outside the scope.
167  auto scope = state.make_region_scope(reg);
168  auto clones = llvm::to_vector(
169  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
170  auto deleteClones = llvm::make_scope_exit([&] {
171  for (Operation *clone : clones)
172  clone->erase();
173  });
174  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
176 
177  bool failed = false;
178  for (Operation &transform : reg.front().without_terminator()) {
180  state.applyTransform(cast<TransformOpInterface>(transform));
181  if (result.isSilenceableFailure()) {
182  LDBG() << "alternative failed: " << result.getMessage();
183  failed = true;
184  break;
185  }
186 
187  if (::mlir::failed(result.silence()))
189  }
190 
191  // If all operations in the given alternative succeeded, no need to consider
192  // the rest. Replace the original scoping operation with the clone on which
193  // the transformations were performed.
194  if (!failed) {
195  // We will be using the clones, so cancel their scheduled deletion.
196  deleteClones.release();
197  TrackingListener listener(state, *this);
198  IRRewriter rewriter(getContext(), &listener);
199  for (const auto &kvp : llvm::zip(originals, clones)) {
200  Operation *original = std::get<0>(kvp);
201  Operation *clone = std::get<1>(kvp);
202  original->getBlock()->getOperations().insert(original->getIterator(),
203  clone);
204  rewriter.replaceOp(original, clone->getResults());
205  }
206  detail::forwardTerminatorOperands(&reg.front(), state, results);
208  }
209  }
210  return emitSilenceableError() << "all alternatives failed";
211 }
212 
213 void transform::AlternativesOp::getEffects(
214  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
215  consumesHandle(getOperation()->getOpOperands(), effects);
216  producesHandle(getOperation()->getOpResults(), effects);
217  for (Region *region : getRegions()) {
218  if (!region->empty())
219  producesHandle(region->front().getArguments(), effects);
220  }
221  modifiesPayload(effects);
222 }
223 
224 LogicalResult transform::AlternativesOp::verify() {
225  for (Region &alternative : getAlternatives()) {
226  Block &block = alternative.front();
227  Operation *terminator = block.getTerminator();
228  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
229  InFlightDiagnostic diag = emitOpError()
230  << "expects terminator operands to have the "
231  "same type as results of the operation";
232  diag.attachNote(terminator->getLoc()) << "terminator";
233  return diag;
234  }
235  }
236 
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // AnnotateOp
242 //===----------------------------------------------------------------------===//
243 
245 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
247  transform::TransformState &state) {
248  SmallVector<Operation *> targets =
249  llvm::to_vector(state.getPayloadOps(getTarget()));
250 
252  if (auto paramH = getParam()) {
253  ArrayRef<Attribute> params = state.getParams(paramH);
254  if (params.size() != 1) {
255  if (targets.size() != params.size()) {
256  return emitSilenceableError()
257  << "parameter and target have different payload lengths ("
258  << params.size() << " vs " << targets.size() << ")";
259  }
260  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
261  target->setAttr(getName(), attr);
263  }
264  attr = params[0];
265  }
266  for (auto *target : targets)
267  target->setAttr(getName(), attr);
269 }
270 
271 void transform::AnnotateOp::getEffects(
272  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
273  onlyReadsHandle(getTargetMutable(), effects);
274  onlyReadsHandle(getParamMutable(), effects);
275  modifiesPayload(effects);
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // ApplyCommonSubexpressionEliminationOp
280 //===----------------------------------------------------------------------===//
281 
283 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
284  transform::TransformRewriter &rewriter, Operation *target,
285  ApplyToEachResultList &results, transform::TransformState &state) {
286  // Make sure that this transform is not applied to itself. Modifying the
287  // transform IR while it is being interpreted is generally dangerous.
288  DiagnosedSilenceableFailure payloadCheck =
290  if (!payloadCheck.succeeded())
291  return payloadCheck;
292 
293  DominanceInfo domInfo;
294  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
296 }
297 
298 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
299  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
300  transform::onlyReadsHandle(getTargetMutable(), effects);
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // ApplyDeadCodeEliminationOp
306 //===----------------------------------------------------------------------===//
307 
308 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
309  transform::TransformRewriter &rewriter, Operation *target,
310  ApplyToEachResultList &results, transform::TransformState &state) {
311  // Make sure that this transform is not applied to itself. Modifying the
312  // transform IR while it is being interpreted is generally dangerous.
313  DiagnosedSilenceableFailure payloadCheck =
315  if (!payloadCheck.succeeded())
316  return payloadCheck;
317 
318  // Maintain a worklist of potentially dead ops.
319  SetVector<Operation *> worklist;
320 
321  // Helper function that adds all defining ops of used values (operands and
322  // operands of nested ops).
323  auto addDefiningOpsToWorklist = [&](Operation *op) {
324  op->walk([&](Operation *op) {
325  for (Value v : op->getOperands())
326  if (Operation *defOp = v.getDefiningOp())
327  if (target->isProperAncestor(defOp))
328  worklist.insert(defOp);
329  });
330  };
331 
332  // Helper function that erases an op.
333  auto eraseOp = [&](Operation *op) {
334  // Remove op and nested ops from the worklist.
335  op->walk([&](Operation *op) {
336  const auto *it = llvm::find(worklist, op);
337  if (it != worklist.end())
338  worklist.erase(it);
339  });
340  rewriter.eraseOp(op);
341  };
342 
343  // Initial walk over the IR.
344  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
345  if (op != target && isOpTriviallyDead(op)) {
346  addDefiningOpsToWorklist(op);
347  eraseOp(op);
348  }
349  });
350 
351  // Erase all ops that have become dead.
352  while (!worklist.empty()) {
353  Operation *op = worklist.pop_back_val();
354  if (!isOpTriviallyDead(op))
355  continue;
356  addDefiningOpsToWorklist(op);
357  eraseOp(op);
358  }
359 
361 }
362 
363 void transform::ApplyDeadCodeEliminationOp::getEffects(
364  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
365  transform::onlyReadsHandle(getTargetMutable(), effects);
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // ApplyPatternsOp
371 //===----------------------------------------------------------------------===//
372 
373 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
374  transform::TransformRewriter &rewriter, Operation *target,
375  ApplyToEachResultList &results, transform::TransformState &state) {
376  // Make sure that this transform is not applied to itself. Modifying the
377  // transform IR while it is being interpreted is generally dangerous. Even
378  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
379  // performs many additional simplifications such as dead code elimination.
380  DiagnosedSilenceableFailure payloadCheck =
382  if (!payloadCheck.succeeded())
383  return payloadCheck;
384 
385  // Gather all specified patterns.
386  MLIRContext *ctx = target->getContext();
388  if (!getRegion().empty()) {
389  for (Operation &op : getRegion().front()) {
390  cast<transform::PatternDescriptorOpInterface>(&op)
391  .populatePatternsWithState(patterns, state);
392  }
393  }
394 
395  // Configure the GreedyPatternRewriteDriver.
397  config.setListener(
398  static_cast<RewriterBase::Listener *>(rewriter.getListener()));
399  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
400 
401  config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
403  : getMaxIterations());
404  config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
406  : getMaxNumRewrites());
407 
408  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
409  // was requested, apply the greedy pattern rewrite only once. (The greedy
410  // pattern rewrite driver already iterates to a fixpoint internally.)
411  bool cseChanged = false;
412  // One or two iterations should be sufficient. Stop iterating after a certain
413  // threshold to make debugging easier.
414  static const int64_t kNumMaxIterations = 50;
415  int64_t iteration = 0;
416  do {
417  LogicalResult result = failure();
418  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
419  // Op is isolated from above. Apply patterns and also perform region
420  // simplification.
421  result = applyPatternsGreedily(target, frozenPatterns, config);
422  } else {
423  // Manually gather list of ops because the other
424  // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
425  // from above. This way, patterns can be applied to ops that are not
426  // isolated from above. Regions are not being simplified. Furthermore,
427  // only a single greedy rewrite iteration is performed.
429  target->walk([&](Operation *nestedOp) {
430  if (target != nestedOp)
431  ops.push_back(nestedOp);
432  });
433  result = applyOpPatternsGreedily(ops, frozenPatterns, config);
434  }
435 
436  // A failure typically indicates that the pattern application did not
437  // converge.
438  if (failed(result)) {
439  return emitSilenceableFailure(target)
440  << "greedy pattern application failed";
441  }
442 
443  if (getApplyCse()) {
444  DominanceInfo domInfo;
445  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
446  &cseChanged);
447  }
448  } while (cseChanged && ++iteration < kNumMaxIterations);
449 
450  if (iteration == kNumMaxIterations)
451  return emitDefiniteFailure() << "fixpoint iteration did not converge";
452 
454 }
455 
456 LogicalResult transform::ApplyPatternsOp::verify() {
457  if (!getRegion().empty()) {
458  for (Operation &op : getRegion().front()) {
459  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
460  InFlightDiagnostic diag = emitOpError()
461  << "expected children ops to implement "
462  "PatternDescriptorOpInterface";
463  diag.attachNote(op.getLoc()) << "op without interface";
464  return diag;
465  }
466  }
467  }
468  return success();
469 }
470 
471 void transform::ApplyPatternsOp::getEffects(
472  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
473  transform::onlyReadsHandle(getTargetMutable(), effects);
475 }
476 
477 void transform::ApplyPatternsOp::build(
478  OpBuilder &builder, OperationState &result, Value target,
479  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
480  result.addOperands(target);
481 
482  OpBuilder::InsertionGuard g(builder);
483  Region *region = result.addRegion();
484  builder.createBlock(region);
485  if (bodyBuilder)
486  bodyBuilder(builder, result.location);
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // ApplyCanonicalizationPatternsOp
491 //===----------------------------------------------------------------------===//
492 
493 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
495  MLIRContext *ctx = patterns.getContext();
496  for (Dialect *dialect : ctx->getLoadedDialects())
497  dialect->getCanonicalizationPatterns(patterns);
499  op.getCanonicalizationPatterns(patterns, ctx);
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // ApplyConversionPatternsOp
504 //===----------------------------------------------------------------------===//
505 
506 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
509  MLIRContext *ctx = getContext();
510 
511  // Instantiate the default type converter if a type converter builder is
512  // specified.
513  std::unique_ptr<TypeConverter> defaultTypeConverter;
514  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
515  getDefaultTypeConverter();
516  if (typeConverterBuilder)
517  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
518 
519  // Configure conversion target.
520  ConversionTarget conversionTarget(*getContext());
521  if (getLegalOps())
522  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
523  conversionTarget.addLegalOp(
524  OperationName(cast<StringAttr>(attr).getValue(), ctx));
525  if (getIllegalOps())
526  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
527  conversionTarget.addIllegalOp(
528  OperationName(cast<StringAttr>(attr).getValue(), ctx));
529  if (getLegalDialects())
530  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
531  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
532  if (getIllegalDialects())
533  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
534  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
535 
536  // Gather all specified patterns.
538  // Need to keep the converters alive until after pattern application because
539  // the patterns take a reference to an object that would otherwise get out of
540  // scope.
541  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
542  if (!getPatterns().empty()) {
543  for (Operation &op : getPatterns().front()) {
544  auto descriptor =
545  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
546 
547  // Check if this pattern set specifies a type converter.
548  std::unique_ptr<TypeConverter> typeConverter =
549  descriptor.getTypeConverter();
550  TypeConverter *converter = nullptr;
551  if (typeConverter) {
552  keepAliveConverters.emplace_back(std::move(typeConverter));
553  converter = keepAliveConverters.back().get();
554  } else {
555  // No type converter specified: Use the default type converter.
556  if (!defaultTypeConverter) {
557  auto diag = emitDefiniteFailure()
558  << "pattern descriptor does not specify type "
559  "converter and apply_conversion_patterns op has "
560  "no default type converter";
561  diag.attachNote(op.getLoc()) << "pattern descriptor op";
562  return diag;
563  }
564  converter = defaultTypeConverter.get();
565  }
566 
567  // Add descriptor-specific updates to the conversion target, which may
568  // depend on the final type converter. In structural converters, the
569  // legality of types dictates the dynamic legality of an operation.
570  descriptor.populateConversionTargetRules(*converter, conversionTarget);
571 
572  descriptor.populatePatterns(*converter, patterns);
573  }
574  }
575 
576  // Attach a tracking listener if handles should be preserved. We configure the
577  // listener to allow op replacements with different names, as conversion
578  // patterns typically replace ops with replacement ops that have a different
579  // name.
580  TrackingListenerConfig trackingConfig;
581  trackingConfig.requireMatchingReplacementOpName = false;
582  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
583  ConversionConfig conversionConfig;
584  if (getPreserveHandles())
585  conversionConfig.listener = &trackingListener;
586 
587  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
588  for (Operation *target : state.getPayloadOps(getTarget())) {
589  // Make sure that this transform is not applied to itself. Modifying the
590  // transform IR while it is being interpreted is generally dangerous.
591  DiagnosedSilenceableFailure payloadCheck =
593  if (!payloadCheck.succeeded())
594  return payloadCheck;
595 
596  LogicalResult status = failure();
597  if (getPartialConversion()) {
598  status = applyPartialConversion(target, conversionTarget, frozenPatterns,
599  conversionConfig);
600  } else {
601  status = applyFullConversion(target, conversionTarget, frozenPatterns,
602  conversionConfig);
603  }
604 
605  // Check dialect conversion state.
607  if (failed(status)) {
608  diag = emitSilenceableError() << "dialect conversion failed";
609  diag.attachNote(target->getLoc()) << "target op";
610  }
611 
612  // Check tracking listener error state.
613  DiagnosedSilenceableFailure trackingFailure =
614  trackingListener.checkAndResetError();
615  if (!trackingFailure.succeeded()) {
616  if (diag.succeeded()) {
617  // Tracking failure is the only failure.
618  return trackingFailure;
619  }
620  diag.attachNote() << "tracking listener also failed: "
621  << trackingFailure.getMessage();
622  (void)trackingFailure.silence();
623  }
624 
625  if (!diag.succeeded())
626  return diag;
627  }
628 
630 }
631 
633  if (getNumRegions() != 1 && getNumRegions() != 2)
634  return emitOpError() << "expected 1 or 2 regions";
635  if (!getPatterns().empty()) {
636  for (Operation &op : getPatterns().front()) {
637  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
639  emitOpError() << "expected pattern children ops to implement "
640  "ConversionPatternDescriptorOpInterface";
641  diag.attachNote(op.getLoc()) << "op without interface";
642  return diag;
643  }
644  }
645  }
646  if (getNumRegions() == 2) {
647  Region &typeConverterRegion = getRegion(1);
648  if (!llvm::hasSingleElement(typeConverterRegion.front()))
649  return emitOpError()
650  << "expected exactly one op in default type converter region";
651  Operation *maybeTypeConverter = &typeConverterRegion.front().front();
652  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
653  maybeTypeConverter);
654  if (!typeConverterOp) {
655  InFlightDiagnostic diag = emitOpError()
656  << "expected default converter child op to "
657  "implement TypeConverterBuilderOpInterface";
658  diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
659  return diag;
660  }
661  // Check default type converter type.
662  if (!getPatterns().empty()) {
663  for (Operation &op : getPatterns().front()) {
664  auto descriptor =
665  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
666  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
667  return failure();
668  }
669  }
670  }
671  return success();
672 }
673 
674 void transform::ApplyConversionPatternsOp::getEffects(
675  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676  if (!getPreserveHandles()) {
677  transform::consumesHandle(getTargetMutable(), effects);
678  } else {
679  transform::onlyReadsHandle(getTargetMutable(), effects);
680  }
682 }
683 
684 void transform::ApplyConversionPatternsOp::build(
685  OpBuilder &builder, OperationState &result, Value target,
686  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
687  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
688  result.addOperands(target);
689 
690  {
691  OpBuilder::InsertionGuard g(builder);
692  Region *region1 = result.addRegion();
693  builder.createBlock(region1);
694  if (patternsBodyBuilder)
695  patternsBodyBuilder(builder, result.location);
696  }
697  {
698  OpBuilder::InsertionGuard g(builder);
699  Region *region2 = result.addRegion();
700  builder.createBlock(region2);
701  if (typeConverterBodyBuilder)
702  typeConverterBodyBuilder(builder, result.location);
703  }
704 }
705 
706 //===----------------------------------------------------------------------===//
707 // ApplyToLLVMConversionPatternsOp
708 //===----------------------------------------------------------------------===//
709 
710 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
711  TypeConverter &typeConverter, RewritePatternSet &patterns) {
712  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
713  assert(dialect && "expected that dialect is loaded");
714  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
715  // ConversionTarget is currently ignored because the enclosing
716  // apply_conversion_patterns op sets up its own ConversionTarget.
717  ConversionTarget target(*getContext());
718  iface->populateConvertToLLVMConversionPatterns(
719  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
720 }
721 
722 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
723  transform::TypeConverterBuilderOpInterface builder) {
724  if (builder.getTypeConverterType() != "LLVMTypeConverter")
725  return emitOpError("expected LLVMTypeConverter");
726  return success();
727 }
728 
730  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
731  if (!dialect)
732  return emitOpError("unknown dialect or dialect not loaded: ")
733  << getDialectName();
734  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
735  if (!iface)
736  return emitOpError(
737  "dialect does not implement ConvertToLLVMPatternInterface or "
738  "extension was not loaded: ")
739  << getDialectName();
740  return success();
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // ApplyLoopInvariantCodeMotionOp
745 //===----------------------------------------------------------------------===//
746 
748 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
749  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
751  transform::TransformState &state) {
752  // Currently, LICM does not remove operations, so we don't need tracking.
753  // If this ever changes, add a LICM entry point that takes a rewriter.
754  moveLoopInvariantCode(target);
756 }
757 
758 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
759  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
760  transform::onlyReadsHandle(getTargetMutable(), effects);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // ApplyRegisteredPassOp
766 //===----------------------------------------------------------------------===//
767 
768 void transform::ApplyRegisteredPassOp::getEffects(
769  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
770  consumesHandle(getTargetMutable(), effects);
771  onlyReadsHandle(getDynamicOptionsMutable(), effects);
772  producesHandle(getOperation()->getOpResults(), effects);
773  modifiesPayload(effects);
774 }
775 
777 transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
779  transform::TransformState &state) {
780  // Obtain a single options-string to pass to the pass(-pipeline) from options
781  // passed in as a dictionary of keys mapping to values which are either
782  // attributes or param-operands pointing to attributes.
783  OperandRange dynamicOptions = getDynamicOptions();
784 
785  std::string options;
786  llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
787 
788  // A helper to convert an option's attribute value into a corresponding
789  // string representation, with the ability to obtain the attr(s) from a param.
790  std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
791  if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
792  // The corresponding value attribute(s) is/are passed in via a param.
793  // Obtain the param-operand via its specified index.
794  int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
795  assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
796  "the number of ParamOperandAttrs in the options DictionaryAttr"
797  "should be the same as the number of options passed as params");
798  ArrayRef<Attribute> attrsAssociatedToParam =
799  state.getParams(dynamicOptions[dynamicOptionIdx]);
800  // Recursive so as to append all attrs associated to the param.
801  llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
802  ",");
803  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
804  // Recursive so as to append all nested attrs of the array.
805  llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
806  } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
807  // Convert to unquoted string.
808  optionsStream << strAttr.getValue().str();
809  } else {
810  // For all other attributes, ask the attr to print itself (without type).
811  valueAttr.print(optionsStream, /*elideType=*/true);
812  }
813  };
814 
815  // Convert the options DictionaryAttr into a single string.
816  llvm::interleave(
817  getOptions(), optionsStream,
818  [&](auto namedAttribute) {
819  optionsStream << namedAttribute.getName().str(); // Append the key.
820  optionsStream << "="; // And the key-value separator.
821  appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
822  },
823  " ");
824  optionsStream.flush();
825 
826  // Get pass or pass pipeline from registry.
827  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
828  if (!info)
829  info = PassInfo::lookup(getPassName());
830  if (!info)
831  return emitDefiniteFailure()
832  << "unknown pass or pass pipeline: " << getPassName();
833 
834  // Create pass manager and add the pass or pass pipeline.
835  PassManager pm(getContext());
836  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
837  emitError(msg);
838  return failure();
839  }))) {
840  return emitDefiniteFailure()
841  << "failed to add pass or pass pipeline to pipeline: "
842  << getPassName();
843  }
844 
845  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
846  for (Operation *target : targets) {
847  // Make sure that this transform is not applied to itself. Modifying the
848  // transform IR while it is being interpreted is generally dangerous. Even
849  // more so when applying passes because they may perform a wide range of IR
850  // modifications.
851  DiagnosedSilenceableFailure payloadCheck =
853  if (!payloadCheck.succeeded())
854  return payloadCheck;
855 
856  // Run the pass or pass pipeline on the current target operation.
857  if (failed(pm.run(target))) {
858  auto diag = emitSilenceableError() << "pass pipeline failed";
859  diag.attachNote(target->getLoc()) << "target op";
860  return diag;
861  }
862  }
863 
864  // The applied pass will have directly modified the payload IR(s).
865  results.set(llvm::cast<OpResult>(getResult()), targets);
867 }
868 
870  OpAsmParser &parser, DictionaryAttr &options,
871  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
  // Parses the custom `{ key = value, ... }` options syntax into `options`.
  // SSA operands used as values are appended to `dynamicOptions`, and a
  // ParamOperandAttr carrying the operand's index is stored in their place.
  // Fails on malformed key/value pairs and on duplicate keys.
872  // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
873  SmallVector<NamedAttribute> keyValuePairs;
874  size_t dynamicOptionsIdx = 0;
875 
876  // Helper for allowing parsing of option values which can be of the form:
877  // - a normal attribute
878  // - an operand (which would be converted to an attr referring to the operand)
879  // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
880  std::function<ParseResult(Attribute &)> parseValue =
881  [&](Attribute &valueAttr) -> ParseResult {
882  // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
883  if (succeeded(parser.parseOptionalLSquare())) {
885 
886  // Recursively parse the array's elements, which might be operands.
887  if (parser.parseCommaSeparatedList(
889  [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
890  " in options dictionary") ||
891  parser.parseRSquare())
892  return failure(); // NB: Attempted parse should've output error message.
893 
894  valueAttr = ArrayAttr::get(parser.getContext(), attrs);
895 
896  return success();
897  }
898 
899  // Parse the value, which can be either an attribute or an operand.
900  OptionalParseResult parsedValueAttr =
901  parser.parseOptionalAttribute(valueAttr);
  // No attribute was present, so the value must be an SSA operand.
902  if (!parsedValueAttr.has_value()) {
904  ParseResult parsedOperand = parser.parseOperand(operand);
905  if (failed(parsedOperand))
906  return failure(); // NB: Attempted parse should've output error message.
907  // To make use of the operand, we need to store it in the options dict.
908  // As SSA-values cannot occur in attributes, what we do instead is store
909  // an attribute in its place that contains the index of the param-operand,
910  // so that an attr-value associated to the param can be resolved later on.
911  dynamicOptions.push_back(operand);
912  auto wrappedIndex = IntegerAttr::get(
913  IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
914  valueAttr =
915  transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
916  } else if (failed(parsedValueAttr.value())) {
917  return failure(); // NB: Attempted parse should have output error message.
918  } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
  // A literal param_operand attr is rejected in the custom syntax; such
  // attrs are only created by this parser when it encounters an operand.
919  return parser.emitError(parser.getCurrentLocation())
920  << "the param_operand attribute is a marker reserved for "
921  << "indicating a value will be passed via params and is only used "
922  << "in the generic print format";
923  }
924 
925  return success();
926  };
927 
928  // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
929  // string and `value` looks like either an attribute or an operand-in-an-attr.
930  std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
931  std::string key;
932  Attribute valueAttr;
933 
934  if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
935  return parser.emitError(parser.getCurrentLocation())
936  << "expected key to either be an identifier or a string";
937 
938  if (failed(parser.parseEqual()))
939  return parser.emitError(parser.getCurrentLocation())
940  << "expected '=' after key in key-value pair";
941 
942  if (failed(parseValue(valueAttr)))
943  return parser.emitError(parser.getCurrentLocation())
944  << "expected a valid attribute or operand as value associated "
945  << "to key '" << key << "'";
946 
947  keyValuePairs.push_back(NamedAttribute(key, valueAttr));
948 
949  return success();
950  };
951 
954  " in options dictionary"))
955  return failure(); // NB: Attempted parse should have output error message.
956 
  // Reject duplicate keys. findDuplicate also sorts keyValuePairs in place,
  // which is what makes the getWithSorted call below valid.
957  if (DictionaryAttr::findDuplicate(
958  keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
959  .has_value())
960  return parser.emitError(parser.getCurrentLocation())
961  << "duplicate keys found in options dictionary";
962 
963  options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
964 
965  return success();
966 }
967 
969  Operation *op,
970  DictionaryAttr options,
971  ValueRange dynamicOptions) {
  // Prints `options` in the custom `{key = value, ...}` form. ParamOperandAttr
  // values are printed as the SSA operand of `dynamicOptions` at the stored
  // index, and arrays are printed element-wise so that nested operands are
  // pretty-printed too. An empty dictionary prints nothing at all.
972  if (options.empty())
973  return;
974 
  // Recursive so that operands nested inside arrays are printed as operands.
975  std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
976  if (auto paramOperandAttr =
977  dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
978  // Resolve index of param-operand to its actual SSA-value and print that.
  // NOTE(review): the index is not bounds-checked here; this presumably
  // relies on the op verifier having validated it.
979  printer.printOperand(
980  dynamicOptions[paramOperandAttr.getIndex().getInt()]);
981  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
982  // This case is so that ArrayAttr-contained operands are pretty-printed.
983  printer << "[";
984  llvm::interleaveComma(arrayAttr, printer, printOptionValue);
985  printer << "]";
986  } else {
987  printer.printAttribute(valueAttr);
988  }
989  };
990 
991  printer << "{";
992  llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
993  printer << namedAttribute.getName();
994  printer << " = ";
995  printOptionValue(namedAttribute.getValue());
996  });
997  printer << "}";
998 }
999 
1001  // Check that there is a one-to-one correspondence between param operands
1002  // and references to dynamic options in the options dictionary.
1003 
1004  auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
  // Work on a local copy: entries are nulled out as they get referenced, so
  // any entry still non-null at the end is an unreferenced operand.
1005 
1006  // Helper for option values to mark seen operands as having been seen (once).
1007  std::function<LogicalResult(Attribute)> checkOptionValue =
1008  [&](Attribute valueAttr) -> LogicalResult {
1009  if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1010  int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1011  if (dynamicOptionIdx < 0 ||
1012  dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1013  return emitOpError()
1014  << "dynamic option index " << dynamicOptionIdx
1015  << " is out of bounds for the number of dynamic options: "
1016  << dynamicOptions.size();
  // A null entry means an earlier param_operand attr already claimed it.
1017  if (dynamicOptions[dynamicOptionIdx] == nullptr)
1018  return emitOpError() << "dynamic option index " << dynamicOptionIdx
1019  << " is already used in options";
1020  dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1021  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1022  // Recurse into ArrayAttrs as they may contain references to operands.
1023  for (auto eltAttr : arrayAttr)
1024  if (failed(checkOptionValue(eltAttr)))
1025  return failure();
1026  }
1027  return success();
1028  };
1029 
1030  for (NamedAttribute namedAttr : getOptions())
1031  if (failed(checkOptionValue(namedAttr.getValue())))
1032  return failure();
1033 
1034  // All dynamicOptions-params seen in the dict will have been set to null.
1035  for (Value dynamicOption : dynamicOptions)
1036  if (dynamicOption)
1037  return emitOpError() << "a param operand does not have a corresponding "
1038  << "param_operand attr in the options dict";
1039 
1040  return success();
1041 }
1042 
1043 //===----------------------------------------------------------------------===//
1044 // CastOp
1045 //===----------------------------------------------------------------------===//
1046 
1048 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1049  Operation *target, ApplyToEachResultList &results,
1050  transform::TransformState &state) {
  // The target payload op is forwarded unchanged to the result handle.
1051  results.push_back(target);
1053 }
1054 
1055 void transform::CastOp::getEffects(
1056  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1057  onlyReadsPayload(effects);
1058  onlyReadsHandle(getInputMutable(), effects);
1059  producesHandle(getOperation()->getOpResults(), effects);
1060 }
1061 
1062 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1063  assert(inputs.size() == 1 && "expected one input");
1064  assert(outputs.size() == 1 && "expected one output");
1065  return llvm::all_of(
1066  std::initializer_list<Type>{inputs.front(), outputs.front()},
1067  llvm::IsaPred<transform::TransformHandleTypeInterface>);
1068 }
1069 
1070 //===----------------------------------------------------------------------===//
1071 // Matching helpers shared by CollectMatchingOp and ForeachMatchOp
1072 //===----------------------------------------------------------------------===//
1073 
1074 /// Applies matcher operations from the given `block` using
1075 /// `blockArgumentMapping` to initialize block arguments, inside a region scope
1076 /// for the block's parent region in `state`. Every non-terminator op must
1077 /// implement MatchOpInterface; otherwise a definite failure is reported and
1078 /// returned. The first matcher that does not succeed (silenceably or
1079 /// definitely) stops matching, and its result is returned unchanged so the
1080 /// caller can decide whether to silence it. If all matchers succeed, `mappings`
1081 /// is cleared and then populated with the payload entities associated with the
/// block terminator operands.
1084  ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1086  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
1087  assert(block.getParent() && "cannot match using a detached block");
1088  auto matchScope = state.make_region_scope(*block.getParent());
1089  if (failed(
1090  state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1092 
1093  for (Operation &match : block.without_terminator()) {
1094  if (!isa<transform::MatchOpInterface>(match)) {
1095  return emitDefiniteFailure(match.getLoc())
1096  << "expected operations in the match part to "
1097  "implement MatchOpInterface";
1098  }
1100  state.applyTransform(cast<transform::TransformOpInterface>(match));
1101  if (diag.succeeded())
1102  continue;
1103 
  // Propagate the failure as-is; it is not silenced here.
1104  return diag;
1105  }
1106 
1107  // Remember the values mapped to the terminator operands so we can
1108  // forward them to the action.
1109  ValueRange yieldedValues = block.getTerminator()->getOperands();
1110  // Our contract with the caller is that the mappings will contain only the
1111  // newly mapped values, clear the rest.
1112  mappings.clear();
1113  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1115 }
1116 
1117 /// Returns `true` if both types implement one of the interfaces provided as
1118 /// template parameters.
1119 template <typename... Tys>
1120 static bool implementSameInterface(Type t1, Type t2) {
1121  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1122 }
1123 
1124 /// Returns `true` if both types implement one of the transform dialect
1125 /// interfaces.
1127  return implementSameInterface<transform::TransformHandleTypeInterface,
1128  transform::TransformParamTypeInterface,
1129  transform::TransformValueHandleTypeInterface>(
1130  t1, t2);
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // CollectMatchingOp
1135 //===----------------------------------------------------------------------===//
1136 
1138 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1139  transform::TransformResults &results,
1140  transform::TransformState &state) {
1141  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1142  getOperation(), getMatcher());
1143  if (matcher.isExternal()) {
1144  return emitDefiniteFailure()
1145  << "unresolved external symbol " << getMatcher();
1146  }
1147 
1148  SmallVector<SmallVector<MappedValue>, 2> rawResults;
1149  rawResults.resize(getOperation()->getNumResults());
1150  std::optional<DiagnosedSilenceableFailure> maybeFailure;
1151  for (Operation *root : state.getPayloadOps(getRoot())) {
1152  WalkResult walkResult = root->walk([&](Operation *op) {
1153  LDBG(DEBUG_TYPE_MATCHER, 1)
1154  << "matching "
1155  << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1156  << " @" << op;
1157 
1158  // Try matching.
1160  SmallVector<transform::MappedValue> inputMapping({op});
1162  matcher.getFunctionBody().front(),
1163  ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1164  mappings);
1165  if (diag.isDefiniteFailure())
1166  return WalkResult::interrupt();
1167  if (diag.isSilenceableFailure()) {
1168  LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1169  << " failed: " << diag.getMessage();
1170  return WalkResult::advance();
1171  }
1172 
1173  // If succeeded, collect results.
1174  for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1175  if (mapping.size() != 1) {
1176  maybeFailure.emplace(emitSilenceableError()
1177  << "result #" << i << ", associated with "
1178  << mapping.size()
1179  << " payload objects, expected 1");
1180  return WalkResult::interrupt();
1181  }
1182  rawResults[i].push_back(mapping[0]);
1183  }
1184  return WalkResult::advance();
1185  });
1186  if (walkResult.wasInterrupted())
1187  return std::move(*maybeFailure);
1188  assert(!maybeFailure && "failure set but the walk was not interrupted");
1189 
1190  for (auto &&[opResult, rawResult] :
1191  llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1192  results.setMappedValues(opResult, rawResult);
1193  }
1194  }
1196 }
1197 
1198 void transform::CollectMatchingOp::getEffects(
1199  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1200  onlyReadsHandle(getRootMutable(), effects);
1201  producesHandle(getOperation()->getOpResults(), effects);
1202  onlyReadsPayload(effects);
1203 }
1204 
1205 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1206  SymbolTableCollection &symbolTable) {
1207  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1208  symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1209  if (!matcherSymbol ||
1210  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1211  return emitError() << "unresolved matcher symbol " << getMatcher();
1212 
1213  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1214  if (argumentTypes.size() != 1 ||
1215  !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1216  return emitError()
1217  << "expected the matcher to take one operation handle argument";
1218  }
1219  if (!matcherSymbol.getArgAttr(
1220  0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1221  return emitError() << "expected the matcher argument to be marked readonly";
1222  }
1223 
1224  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1225  if (resultTypes.size() != getOperation()->getNumResults()) {
1226  return emitError()
1227  << "expected the matcher to yield as many values as op has results ("
1228  << getOperation()->getNumResults() << "), got "
1229  << resultTypes.size();
1230  }
1231 
1232  for (auto &&[i, matcherType, resultType] :
1233  llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1234  if (implementSameTransformInterface(matcherType, resultType))
1235  continue;
1236 
1237  return emitError()
1238  << "mismatching type interfaces for matcher result and op result #"
1239  << i;
1240  }
1241 
1242  return success();
1243 }
1244 
1245 //===----------------------------------------------------------------------===//
1246 // ForeachMatchOp
1247 //===----------------------------------------------------------------------===//
1248 
1249 // This is fine because nothing is actually consumed by this op.
1250 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1251 
1253 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1254  transform::TransformResults &results,
1255  transform::TransformState &state) {
  // Walks the payload nested under each root payload op. For each payload op,
  // tries the matcher/action pairs in order and runs the action of the first
  // matcher that succeeds. Silenceable action failures are collected into one
  // diagnostic that is returned at the end. Definite failures interrupt the
  // walk.
1257  matchActionPairs;
1258  matchActionPairs.reserve(getMatchers().size());
1259  SymbolTableCollection symbolTable;
1260  for (auto &&[matcher, action] :
1261  llvm::zip_equal(getMatchers(), getActions())) {
1262  auto matcherSymbol =
1263  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1264  getOperation(), cast<SymbolRefAttr>(matcher));
1265  auto actionSymbol =
1266  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1267  getOperation(), cast<SymbolRefAttr>(action));
1268  assert(matcherSymbol && actionSymbol &&
1269  "unresolved symbols not caught by the verifier");
1270 
1271  if (matcherSymbol.isExternal())
1272  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1273  if (actionSymbol.isExternal())
1274  return emitDefiniteFailure() << "unresolved external symbol " << action;
1275 
1276  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1277  }
1278 
1279  DiagnosedSilenceableFailure overallDiag =
1281 
1282  SmallVector<SmallVector<MappedValue>> matchInputMapping;
1283  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1284  SmallVector<SmallVector<MappedValue>> actionResultMapping;
1285  // Explicitly add the mapping for the first block argument (the op being
1286  // matched).
1287  matchInputMapping.emplace_back();
1288  transform::detail::prepareValueMappings(matchInputMapping,
1289  getForwardedInputs(), state);
1290  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1291  actionResultMapping.resize(getForwardedOutputs().size());
1292 
1293  for (Operation *root : state.getPayloadOps(getRoot())) {
1294  WalkResult walkResult = root->walk([&](Operation *op) {
1295  // If getRestrictRoot is not present, skip over the root op itself so we
1296  // don't invalidate it.
1297  if (!getRestrictRoot() && op == root)
1298  return WalkResult::advance();
1299 
1300  LDBG(DEBUG_TYPE_MATCHER, 1)
1301  << "matching "
1302  << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1303  << " @" << op;
1304 
  // Rebind only the first matcher argument to the current payload op; the
  // forwarded-input mappings in the remaining slots are reused as-is.
1305  firstMatchArgument.clear();
1306  firstMatchArgument.push_back(op);
1307 
1308  // Try all the match/action pairs until the first successful match.
1309  for (auto [matcher, action] : matchActionPairs) {
1311  matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1312  state, matchOutputMapping);
1313  if (diag.isDefiniteFailure())
1314  return WalkResult::interrupt();
1315  if (diag.isSilenceableFailure()) {
1316  LDBG(DEBUG_TYPE_MATCHER, 1) << "matcher " << matcher.getName()
1317  << " failed: " << diag.getMessage();
1318  continue;
1319  }
1320 
  // The matcher succeeded: bind its yielded values to the action's block
  // arguments and run the action body.
1321  auto scope = state.make_region_scope(action.getFunctionBody());
1322  if (failed(state.mapBlockArguments(
1323  action.getFunctionBody().front().getArguments(),
1324  matchOutputMapping))) {
1325  return WalkResult::interrupt();
1326  }
1327 
1328  for (Operation &transform :
1329  action.getFunctionBody().front().without_terminator()) {
1331  state.applyTransform(cast<TransformOpInterface>(transform));
1332  if (result.isDefiniteFailure())
1333  return WalkResult::interrupt();
1334  if (result.isSilenceableFailure()) {
  // Record the failure in the aggregate diagnostic and keep going with
  // the remaining transforms of the action.
1335  if (overallDiag.succeeded()) {
1336  overallDiag = emitSilenceableError() << "actions failed";
1337  }
1338  overallDiag.attachNote(action->getLoc())
1339  << "failed action: " << result.getMessage();
1340  overallDiag.attachNote(op->getLoc())
1341  << "when applied to this matching payload";
1342  (void)result.silence();
1343  continue;
1344  }
1345  }
1347  MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1348  action.getFunctionBody().front().getTerminator()->getOperands(),
1349  state, getFlattenResults()))) {
1351  << "action @" << action.getName()
1352  << " has results associated with multiple payload entities, "
1353  "but flattening was not requested";
1354  return WalkResult::interrupt();
1355  }
  // Only the first successful match/action pair is applied to `op`.
1356  break;
1357  }
1358  return WalkResult::advance();
1359  });
1360  if (walkResult.wasInterrupted())
1362  }
1363 
1364  // The root operation should not have been affected, so we can just reassign
1365  // the payload to the result. Note that we need to consume the root handle to
1366  // make sure any handles to operations inside, that could have been affected
1367  // by actions, are invalidated.
1368  results.set(llvm::cast<OpResult>(getUpdated()),
1369  state.getPayloadOps(getRoot()));
1370  for (auto &&[result, mapping] :
1371  llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1372  results.setMappedValues(result, mapping);
1373  }
1374  return overallDiag;
1375 }
1376 
1377 void transform::ForeachMatchOp::getAsmResultNames(
1378  OpAsmSetValueNameFn setNameFn) {
1379  setNameFn(getUpdated(), "updated_root");
1380  for (Value v : getForwardedOutputs()) {
1381  setNameFn(v, "yielded");
1382  }
1383 }
1384 
1385 void transform::ForeachMatchOp::getEffects(
1386  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1387  // Bail if invalid.
1388  if (getOperation()->getNumOperands() < 1 ||
1389  getOperation()->getNumResults() < 1) {
1390  return modifiesPayload(effects);
1391  }
1392 
1393  consumesHandle(getRootMutable(), effects);
1394  onlyReadsHandle(getForwardedInputsMutable(), effects);
1395  producesHandle(getOperation()->getOpResults(), effects);
1396  modifiesPayload(effects);
1397 }
1398 
1399 /// Parses the comma-separated list of symbol reference pairs of the format
1400 /// `@matcher -> @action`.
1401 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1402  ArrayAttr &matchers,
1403  ArrayAttr &actions) {
1404  StringAttr matcher;
1405  StringAttr action;
1406  SmallVector<Attribute> matcherList;
1407  SmallVector<Attribute> actionList;
1408  do {
1409  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1410  parser.parseSymbolName(action)) {
1411  return failure();
1412  }
1413  matcherList.push_back(SymbolRefAttr::get(matcher));
1414  actionList.push_back(SymbolRefAttr::get(action));
1415  } while (parser.parseOptionalComma().succeeded());
1416 
1417  matchers = parser.getBuilder().getArrayAttr(matcherList);
1418  actions = parser.getBuilder().getArrayAttr(actionList);
1419  return success();
1420 }
1421 
1422 /// Prints the comma-separated list of symbol reference pairs of the format
1423 /// `@matcher -> @action`.
1425  ArrayAttr matchers, ArrayAttr actions) {
1426  printer.increaseIndent();
1427  printer.increaseIndent();
1428  for (auto &&[matcher, action, idx] : llvm::zip_equal(
1429  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1430  printer.printNewline();
1431  printer << cast<SymbolRefAttr>(matcher) << " -> "
1432  << cast<SymbolRefAttr>(action);
1433  if (idx != matchers.size() - 1)
1434  printer << ", ";
1435  }
1436  printer.decreaseIndent();
1437  printer.decreaseIndent();
1438 }
1439 
1440 LogicalResult transform::ForeachMatchOp::verify() {
1441  if (getMatchers().size() != getActions().size())
1442  return emitOpError() << "expected the same number of matchers and actions";
1443  if (getMatchers().empty())
1444  return emitOpError() << "expected at least one match/action pair";
1445 
1446  llvm::SmallPtrSet<Attribute, 8> matcherNames;
1447  for (Attribute name : getMatchers()) {
1448  if (matcherNames.insert(name).second)
1449  continue;
1450  emitWarning() << "matcher " << name
1451  << " is used more than once, only the first match will apply";
1452  }
1453 
1454  return success();
1455 }
1456 
1457 /// Checks that the attributes of the function-like operation have correct
1458 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
1459 /// annotations being present even if they can be inferred from the body.
/// If `emitWarnings`, also warns about arguments marked consumed that the body
/// does not consume. Returns a silenceable failure for the first violation.
1461 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1462  bool alsoVerifyInternal = false) {
1463  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1464  llvm::SmallDenseSet<unsigned> consumedArguments;
1465  if (!op.isExternal()) {
1466  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1467  consumedArguments);
1468  }
1469  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1470  bool isConsumed =
1471  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1472  nullptr;
1473  bool isReadOnly =
1474  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1475  nullptr;
1476  if (isConsumed && isReadOnly) {
1477  return transformOp.emitSilenceableError()
1478  << "argument #" << i << " cannot be both readonly and consumed";
1479  }
1480  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1481  return transformOp.emitSilenceableError()
1482  << "must provide consumed/readonly status for arguments of "
1483  "external or called ops";
1484  }
  // External functions have no body to compare the annotations against.
1485  if (op.isExternal())
1486  continue;
1487 
  // Only arguments explicitly marked readonly are flagged here. `isReadOnly`
  // already implies `!isConsumed` because of the check above.
1488  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1489  return transformOp.emitSilenceableError()
1490  << "argument #" << i
1491  << " is consumed in the body but is not marked as such";
1492  }
1493  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1494  // Cannot use op.emitWarning() here as it would attempt to verify the op
1495  // before printing, resulting in infinite recursion.
1496  emitWarning(op->getLoc())
1497  << "op argument #" << i
1498  << " is not consumed in the body but is marked as consumed";
1499  }
1500  }
1502 }
1503 
1504 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1505  SymbolTableCollection &symbolTable) {
1506  assert(getMatchers().size() == getActions().size());
1507  auto consumedAttr =
1508  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1509  for (auto &&[matcher, action] :
1510  llvm::zip_equal(getMatchers(), getActions())) {
1511  // Presence and typing.
1512  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1513  symbolTable.lookupNearestSymbolFrom(getOperation(),
1514  cast<SymbolRefAttr>(matcher)));
1515  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1516  symbolTable.lookupNearestSymbolFrom(getOperation(),
1517  cast<SymbolRefAttr>(action)));
1518  if (!matcherSymbol ||
1519  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1520  return emitError() << "unresolved matcher symbol " << matcher;
1521  if (!actionSymbol ||
1522  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1523  return emitError() << "unresolved action symbol " << action;
1524 
1525  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1526  /*emitWarnings=*/false,
1527  /*alsoVerifyInternal=*/true)
1528  .checkAndReport())) {
1529  return failure();
1530  }
1532  /*emitWarnings=*/false,
1533  /*alsoVerifyInternal=*/true)
1534  .checkAndReport())) {
1535  return failure();
1536  }
1537 
1538  // Input -> matcher forwarding.
1539  TypeRange operandTypes = getOperandTypes();
1540  TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1541  if (operandTypes.size() != matcherArguments.size()) {
1543  emitError() << "the number of operands (" << operandTypes.size()
1544  << ") doesn't match the number of matcher arguments ("
1545  << matcherArguments.size() << ") for " << matcher;
1546  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1547  return diag;
1548  }
1549  for (auto &&[i, operand, argument] :
1550  llvm::enumerate(operandTypes, matcherArguments)) {
1551  if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1553  emitOpError()
1554  << "does not expect matcher symbol to consume its operand #" << i;
1555  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1556  return diag;
1557  }
1558 
1559  if (implementSameTransformInterface(operand, argument))
1560  continue;
1561 
1563  emitError()
1564  << "mismatching type interfaces for operand and matcher argument #"
1565  << i << " of matcher " << matcher;
1566  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1567  return diag;
1568  }
1569 
1570  // Matcher -> action forwarding.
1571  TypeRange matcherResults = matcherSymbol.getResultTypes();
1572  TypeRange actionArguments = actionSymbol.getArgumentTypes();
1573  if (matcherResults.size() != actionArguments.size()) {
1574  return emitError() << "mismatching number of matcher results and "
1575  "action arguments between "
1576  << matcher << " (" << matcherResults.size() << ") and "
1577  << action << " (" << actionArguments.size() << ")";
1578  }
1579  for (auto &&[i, matcherType, actionType] :
1580  llvm::enumerate(matcherResults, actionArguments)) {
1581  if (implementSameTransformInterface(matcherType, actionType))
1582  continue;
1583 
1584  return emitError() << "mismatching type interfaces for matcher result "
1585  "and action argument #"
1586  << i << "of matcher " << matcher << " and action "
1587  << action;
1588  }
1589 
1590  // Action -> result forwarding.
1591  TypeRange actionResults = actionSymbol.getResultTypes();
1592  auto resultTypes = TypeRange(getResultTypes()).drop_front();
1593  if (actionResults.size() != resultTypes.size()) {
1595  emitError() << "the number of action results ("
1596  << actionResults.size() << ") for " << action
1597  << " doesn't match the number of extra op results ("
1598  << resultTypes.size() << ")";
1599  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1600  return diag;
1601  }
1602  for (auto &&[i, resultType, actionType] :
1603  llvm::enumerate(resultTypes, actionResults)) {
1604  if (implementSameTransformInterface(resultType, actionType))
1605  continue;
1606 
1608  emitError() << "mismatching type interfaces for action result #" << i
1609  << " of action " << action << " and op result";
1610  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1611  return diag;
1612  }
1613  }
1614  return success();
1615 }
1616 
1617 //===----------------------------------------------------------------------===//
1618 // ForeachOp
1619 //===----------------------------------------------------------------------===//
1620 
1622 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1623  transform::TransformResults &results,
1624  transform::TransformState &state) {
1625  // We store the payloads before executing the body as ops may be removed from
1626  // the mapping by the TrackingRewriter while iteration is in progress.
1628  detail::prepareValueMappings(payloads, getTargets(), state);
1629  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1630  bool withZipShortest = getWithZipShortest();
1631 
1632  // In case of `zip_shortest`, set the number of iterations to the
1633  // smallest payload in the targets.
1634  if (withZipShortest) {
1635  numIterations =
1636  llvm::min_element(payloads, [&](const SmallVector<MappedValue> &a,
1637  const SmallVector<MappedValue> &b) {
1638  return a.size() < b.size();
1639  })->size();
1640 
1641  for (auto &payload : payloads)
1642  payload.resize(numIterations);
1643  }
1644 
1645  // As we will be "zipping" over them, check all payloads have the same size.
1646  // `zip_shortest` adjusts all payloads to the same size, so skip this check
1647  // when true.
1648  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1649  argIdx++) {
1650  if (payloads[argIdx].size() != numIterations) {
1651  return emitSilenceableError()
1652  << "prior targets' payload size (" << numIterations
1653  << ") differs from payload size (" << payloads[argIdx].size()
1654  << ") of target " << getTargets()[argIdx];
1655  }
1656  }
1657 
1658  // Start iterating, indexing into payloads to obtain the right arguments to
1659  // call the body with - each slice of payloads at the same argument index
1660  // corresponding to a tuple to use as the body's block arguments.
1661  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1662  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1663  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1664  auto scope = state.make_region_scope(getBody());
1665  // Set up arguments to the region's block.
1666  for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1667  MappedValue argument = payloads[argIdx][iterIdx];
1668  // Note that each blockArg's handle gets associated with just a single
1669  // element from the corresponding target's payload.
1670  if (failed(state.mapBlockArgument(blockArg, {argument})))
1672  }
1673 
1674  // Execute loop body.
1675  for (Operation &transform : getBody().front().without_terminator()) {
1676  DiagnosedSilenceableFailure result = state.applyTransform(
1677  llvm::cast<transform::TransformOpInterface>(transform));
1678  if (!result.succeeded())
1679  return result;
1680  }
1681 
1682  // Append yielded payloads to corresponding results from prior iterations.
1683  OperandRange yieldOperands = getYieldOp().getOperands();
1684  for (auto &&[result, yieldOperand, resTuple] :
1685  llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1686  // NB: each iteration we add any number of ops/vals/params to a result.
1687  if (isa<TransformHandleTypeInterface>(result.getType()))
1688  llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1689  else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1690  llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1691  else if (isa<TransformParamTypeInterface>(result.getType()))
1692  llvm::append_range(resTuple, state.getParams(yieldOperand));
1693  else
1694  assert(false && "unhandled handle type");
1695  }
1696 
1697  // Associate the accumulated result payloads to the op's actual results.
1698  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1699  results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1700 
1702 }
1703 
1704 void transform::ForeachOp::getEffects(
1705  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1706  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1707  // arity errors, this method might get called before/in absence of `verify()`.
1708  for (auto &&[target, blockArg] :
1709  llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1710  BlockArgument blockArgument = blockArg;
1711  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1712  return isHandleConsumed(blockArgument,
1713  cast<TransformOpInterface>(&op));
1714  })) {
1715  consumesHandle(target, effects);
1716  } else {
1717  onlyReadsHandle(target, effects);
1718  }
1719  }
1720 
1721  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1722  return doesModifyPayload(cast<TransformOpInterface>(&op));
1723  })) {
1724  modifiesPayload(effects);
1725  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1726  return doesReadPayload(cast<TransformOpInterface>(&op));
1727  })) {
1728  onlyReadsPayload(effects);
1729  }
1730 
1731  producesHandle(getOperation()->getOpResults(), effects);
1732 }
1733 
1734 void transform::ForeachOp::getSuccessorRegions(
1735  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1736  Region *bodyRegion = &getBody();
1737  if (point.isParent()) {
1738  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1739  return;
1740  }
1741 
1742  // Branch back to the region or the parent.
1743  assert(point == getBody() && "unexpected region index");
1744  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1745  regions.emplace_back();
1746 }
1747 
transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  // Each block argument handle is mapped to a subset (one op to be precise)
  // of the payload of the corresponding `targets` operand of ForeachOp.
  // Only entry into the body region is expected; all of this op's operands are
  // reported as forwarded to it.
  assert(point == getBody() && "unexpected region index");
  return getOperation()->getOperands();
}
1755 
1756 transform::YieldOp transform::ForeachOp::getYieldOp() {
1757  return cast<transform::YieldOp>(getBody().front().getTerminator());
1758 }
1759 
1760 LogicalResult transform::ForeachOp::verify() {
1761  for (auto [targetOpt, bodyArgOpt] :
1762  llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1763  if (!targetOpt || !bodyArgOpt)
1764  return emitOpError() << "expects the same number of targets as the body "
1765  "has block arguments";
1766  if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1767  return emitOpError(
1768  "expects co-indexed targets and the body's "
1769  "block arguments to have the same op/value/param type");
1770  }
1771 
1772  for (auto [resultOpt, yieldOperandOpt] :
1773  llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1774  if (!resultOpt || !yieldOperandOpt)
1775  return emitOpError() << "expects the same number of results as the "
1776  "yield terminator has operands";
1777  if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1778  return emitOpError("expects co-indexed results and yield "
1779  "operands to have the same op/value/param type");
1780  }
1781 
1782  return success();
1783 }
1784 
1785 //===----------------------------------------------------------------------===//
1786 // GetParentOp
1787 //===----------------------------------------------------------------------===//
1788 
transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  // For every payload op of `target`, walks up `nth_parent` times, each time
  // to the closest ancestor that passes the optional isolated-from-above and
  // op-name filters, and collects the final ancestor.
  SmallVector<Operation *> parents;
  // Only consulted when deduplication is requested.
  DenseSet<Operation *> resultSet;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    Operation *parent = target;
    for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
      parent = parent->getParentOp();
      // Skip ancestors until one satisfies every requested filter.
      while (parent) {
        bool checkIsolatedFromAbove =
            !getIsolatedFromAbove() ||
        bool checkOpName = !getOpName().has_value() ||
                           parent->getName().getStringRef() == *getOpName();
        if (checkIsolatedFromAbove && checkOpName)
          break;
        parent = parent->getParentOp();
      }
      if (!parent) {
        // NOTE(review): with `allow_empty_results`, the result is set to the
        // parents already collected for earlier targets, which may be
        // non-empty; confirm this is the intended semantics.
        if (getAllowEmptyResults()) {
          results.set(llvm::cast<OpResult>(getResult()), parents);
        }
            emitSilenceableError()
            << "could not find a parent op that matches all requirements";
        diag.attachNote(target->getLoc()) << "target op";
        return diag;
      }
    }
    if (getDeduplicate()) {
      if (resultSet.insert(parent).second)
        parents.push_back(parent);
    } else {
      parents.push_back(parent);
    }
  }
  results.set(llvm::cast<OpResult>(getResult()), parents);
}
1831 
1832 //===----------------------------------------------------------------------===//
1833 // GetConsumersOfResult
1834 //===----------------------------------------------------------------------===//
1835 
1837 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1838  transform::TransformResults &results,
1839  transform::TransformState &state) {
1840  int64_t resultNumber = getResultNumber();
1841  auto payloadOps = state.getPayloadOps(getTarget());
1842  if (std::empty(payloadOps)) {
1843  results.set(cast<OpResult>(getResult()), {});
1845  }
1846  if (!llvm::hasSingleElement(payloadOps))
1847  return emitDefiniteFailure()
1848  << "handle must be mapped to exactly one payload op";
1849 
1850  Operation *target = *payloadOps.begin();
1851  if (target->getNumResults() <= resultNumber)
1852  return emitDefiniteFailure() << "result number overflow";
1853  results.set(llvm::cast<OpResult>(getResult()),
1854  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1856 }
1857 
1858 //===----------------------------------------------------------------------===//
1859 // GetDefiningOp
1860 //===----------------------------------------------------------------------===//
1861 
transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  // Maps the result to the defining op of each payload value of `target`, in
  // order and without deduplication. Block arguments have no defining op and
  // produce a silenceable error.
  SmallVector<Operation *> definingOps;
  for (Value v : state.getPayloadValues(getTarget())) {
    if (llvm::isa<BlockArgument>(v)) {
          emitSilenceableError() << "cannot get defining op of block argument";
      diag.attachNote(v.getLoc()) << "target value";
      return diag;
    }
    definingOps.push_back(v.getDefiningOp());
  }
  results.set(llvm::cast<OpResult>(getResult()), definingOps);
}
1879 
1880 //===----------------------------------------------------------------------===//
1881 // GetProducerOfOperand
1882 //===----------------------------------------------------------------------===//
1883 
1885 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1886  transform::TransformResults &results,
1887  transform::TransformState &state) {
1888  int64_t operandNumber = getOperandNumber();
1889  SmallVector<Operation *> producers;
1890  for (Operation *target : state.getPayloadOps(getTarget())) {
1891  Operation *producer =
1892  target->getNumOperands() <= operandNumber
1893  ? nullptr
1894  : target->getOperand(operandNumber).getDefiningOp();
1895  if (!producer) {
1897  emitSilenceableError()
1898  << "could not find a producer for operand number: " << operandNumber
1899  << " of " << *target;
1900  diag.attachNote(target->getLoc()) << "target op";
1901  return diag;
1902  }
1903  producers.push_back(producer);
1904  }
1905  results.set(llvm::cast<OpResult>(getResult()), producers);
1907 }
1908 
1909 //===----------------------------------------------------------------------===//
1910 // GetOperandOp
1911 //===----------------------------------------------------------------------===//
1912 
transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
                               transform::TransformResults &results,
                               transform::TransformState &state) {
  // Collects, for every payload op of `target`, the operands at the requested
  // positions, concatenated in payload order.
  SmallVector<Value> operands;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Resolve the all/inverted flags and raw position list against this op's
    // operand count into concrete positions.
    SmallVector<int64_t> operandPositions;
        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
        target->getNumOperands(), operandPositions);
    if (diag.isSilenceableFailure()) {
      diag.attachNote(target->getLoc())
          << "while considering positions of this payload operation";
      return diag;
    }
    llvm::append_range(operands,
                       llvm::map_range(operandPositions, [&](int64_t pos) {
                         return target->getOperand(pos);
                       }));
  }
  results.setValues(cast<OpResult>(getResult()), operands);
}
1936 
1937 LogicalResult transform::GetOperandOp::verify() {
1938  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1939  getIsInverted(), getIsAll());
1940 }
1941 
1942 //===----------------------------------------------------------------------===//
1943 // GetResultOp
1944 //===----------------------------------------------------------------------===//
1945 
transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
                              transform::TransformResults &results,
                              transform::TransformState &state) {
  // Collects, for every payload op of `target`, the results at the requested
  // positions, concatenated in payload order.
  SmallVector<Value> opResults;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Resolve the all/inverted flags and raw position list against this op's
    // result count into concrete positions.
    SmallVector<int64_t> resultPositions;
        getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
        target->getNumResults(), resultPositions);
    if (diag.isSilenceableFailure()) {
      diag.attachNote(target->getLoc())
          << "while considering positions of this payload operation";
      return diag;
    }
    llvm::append_range(opResults,
                       llvm::map_range(resultPositions, [&](int64_t pos) {
                         return target->getResult(pos);
                       }));
  }
  results.setValues(cast<OpResult>(getResult()), opResults);
}
1969 
1970 LogicalResult transform::GetResultOp::verify() {
1971  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1972  getIsInverted(), getIsAll());
1973 }
1974 
1975 //===----------------------------------------------------------------------===//
1976 // GetTypeOp
1977 //===----------------------------------------------------------------------===//
1978 
1979 void transform::GetTypeOp::getEffects(
1980  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1981  onlyReadsHandle(getValueMutable(), effects);
1982  producesHandle(getOperation()->getOpResults(), effects);
1983  onlyReadsPayload(effects);
1984 }
1985 
transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  // Produces one TypeAttr param per payload value of `value`, in order. With
  // `elemental`, shaped types are replaced by their element type; non-shaped
  // types are reported unchanged.
  SmallVector<Attribute> params;
  for (Value value : state.getPayloadValues(getValue())) {
    Type type = value.getType();
    if (getElemental()) {
      if (auto shaped = dyn_cast<ShapedType>(type)) {
        type = shaped.getElementType();
      }
    }
    params.push_back(TypeAttr::get(type));
  }
  results.setParams(cast<OpResult>(getResult()), params);
}
2003 
2004 //===----------------------------------------------------------------------===//
2005 // IncludeOp
2006 //===----------------------------------------------------------------------===//
2007 
/// Applies the transform ops contained in `block`. Maps `results` to the same
/// values as the operands of the block terminator.
///
/// A definite failure of any op is returned immediately. A silenceable failure
/// is returned (after propagating empty results) when `mode` is `Propagate`;
/// otherwise it is silenced and the remaining ops still run.
applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
                   transform::TransformResults &results) {
  // Apply the sequenced ops one by one.
  for (Operation &transform : block.without_terminator()) {
        state.applyTransform(cast<transform::TransformOpInterface>(transform));
    if (result.isDefiniteFailure())
      return result;

    if (result.isSilenceableFailure()) {
      if (mode == transform::FailurePropagationMode::Propagate) {
        // Propagate empty results in case of early exit.
        forwardEmptyOperands(&block, state, results);
        return result;
      }
      (void)result.silence();
    }
  }

  // Forward the operation mapping for values yielded from the sequence to the
  // values produced by the sequence op.
  transform::detail::forwardTerminatorOperands(&block, state, results);
}
2036 
transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  // Runs the body of the referenced named sequence with this op's operand
  // payloads bound to its entry block arguments, then maps this op's results
  // to the payloads of the callee's terminator operands.
  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
      getOperation(), getTarget());
  assert(callee && "unverified reference to unknown symbol");

  if (callee.isExternal())
    return emitDefiniteFailure() << "unresolved external named sequence";

  // Map operands to block arguments.
  detail::prepareValueMappings(mappings, getOperands(), state);
  auto scope = state.make_region_scope(callee.getBody());
  for (auto &&[arg, map] :
       llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
    if (failed(state.mapBlockArgument(arg, map)))
  }

      callee.getBody().front(), getFailurePropagationMode(), state, results);
  // Reuse `mappings` for the payloads of the callee's terminator operands,
  // which become the payloads of this op's results.
  mappings.clear();
      mappings, callee.getBody().front().getTerminator()->getOperands(), state);
  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
    results.setMappedValues(result, mapping);
  return result;
}
2067 
// Forward declaration; the definition appears further down in this file.
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2070 
2071 void transform::IncludeOp::getEffects(
2072  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2073  // Always mark as modifying the payload.
2074  // TODO: a mechanism to annotate effects on payload. Even when all handles are
2075  // only read, the payload may still be modified, so we currently stay on the
2076  // conservative side and always indicate modification. This may prevent some
2077  // code reordering.
2078  modifiesPayload(effects);
2079 
2080  // Results are always produced.
2081  producesHandle(getOperation()->getOpResults(), effects);
2082 
2083  // Adds default effects to operands and results. This will be added if
2084  // preconditions fail so the trait verifier doesn't complain about missing
2085  // effects and the real precondition failure is reported later on.
2086  auto defaultEffects = [&] {
2087  onlyReadsHandle(getOperation()->getOpOperands(), effects);
2088  };
2089 
2090  // Bail if the callee is unknown. This may run as part of the verification
2091  // process before we verified the validity of the callee or of this op.
2092  auto target =
2093  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2094  if (!target)
2095  return defaultEffects();
2096  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2097  getOperation(), getTarget());
2098  if (!callee)
2099  return defaultEffects();
2100 
2101  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2102  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2103  consumesHandle(getOperation()->getOpOperand(i), effects);
2104  else if (callee.getArgAttr(i, TransformDialect::kArgReadOnlyAttrName))
2105  onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2106  }
2107 }
2108 
LogicalResult
transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Checks that `target` names a transform.named_sequence whose signature is
  // compatible with this op: identical operand types and results implementing
  // the same transform type interface.
  // Access through indirection and do additional checking because this may be
  // running before the main op verifier.
  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
  if (!targetAttr)
    return emitOpError() << "expects a 'target' symbol reference attribute";

  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
      *this, targetAttr);
  if (!target)
    return emitOpError() << "does not reference a named transform sequence";

  FunctionType fnType = target.getFunctionType();
  // NOTE(review): the two arity checks use `emitError` while the other checks
  // use `emitOpError`; consider unifying (tests may match the current text).
  if (fnType.getNumInputs() != getNumOperands())
    return emitError("incorrect number of operands for callee");

  // Operand types must match the callee's input types exactly.
  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
    if (getOperand(i).getType() != fnType.getInput(i)) {
      return emitOpError("operand type mismatch: expected operand type ")
             << fnType.getInput(i) << ", but provided "
             << getOperand(i).getType() << " for operand number " << i;
    }
  }

  if (fnType.getNumResults() != getNumResults())
    return emitError("incorrect number of results for callee");

  // Result types only need to implement the same transform interface as the
  // callee's results, not be identical.
  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
    Type resultType = getResult(i).getType();
    Type funcType = fnType.getResult(i);
    if (!implementSameTransformInterface(resultType, funcType)) {
      return emitOpError() << "type of result #" << i
                           << " must implement the same transform dialect "
                              "interface as the corresponding callee result";
    }
  }

  // Finally check the callee itself, without warnings but including internal
  // checks, and report any failure.
      cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
      /*alsoVerifyInternal=*/true)
      .checkAndReport();
}
2152 
2153 //===----------------------------------------------------------------------===//
2154 // MatchOperationEmptyOp
2155 //===----------------------------------------------------------------------===//
2156 
/// Matches when there is no current operation. If an operation is present, a
/// silenceable "operation is not empty" error is emitted. Both outcomes are
/// logged under the matcher debug type.
DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
    ::std::optional<::mlir::Operation *> maybeCurrent,
  if (!maybeCurrent.has_value()) {
    LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp success";
  }
  LDBG(DEBUG_TYPE_MATCHER, 1) << "MatchOperationEmptyOp failure";
  return emitSilenceableError() << "operation is not empty";
}
2167 
2168 //===----------------------------------------------------------------------===//
2169 // MatchOperationNameOp
2170 //===----------------------------------------------------------------------===//
2171 
/// Checks the payload op's name against the accepted `op_names`. If none of
/// them matches exactly, a silenceable "wrong operation name" error is emitted.
DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  StringRef currentOpName = current->getName().getStringRef();
  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
    if (acceptedAttr.getValue() == currentOpName)
  }
  return emitSilenceableError() << "wrong operation name";
}
2182 
2183 //===----------------------------------------------------------------------===//
2184 // MatchParamCmpIOp
2185 //===----------------------------------------------------------------------===//
2186 
transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
  // Compares `param` against `reference` element-wise with the selected
  // predicate, treating the integers as signed. Differing lengths and failed
  // comparisons are silenceable errors. Non-integer values or mismatched
  // integer types are definite failures.
  auto signedAPIntAsString = [&](const APInt &value) {
    std::string str;
    llvm::raw_string_ostream os(str);
    value.print(os, /*isSigned=*/true);
    return str;
  };

  ArrayRef<Attribute> params = state.getParams(getParam());
  ArrayRef<Attribute> references = state.getParams(getReference());

  if (params.size() != references.size()) {
    return emitSilenceableError()
           << "parameters have different payload lengths (" << params.size()
           << " vs " << references.size() << ")";
  }

  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
    auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
    auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
    if (!intAttr || !refAttr) {
      return emitDefiniteFailure()
             << "non-integer parameter value not expected";
    }
    // Equal attribute types guarantee equal APInt bit widths below.
    if (intAttr.getType() != refAttr.getType()) {
      return emitDefiniteFailure()
             << "mismatching integer attribute types in parameter #" << i;
    }
    APInt value = intAttr.getValue();
    APInt refValue = refAttr.getValue();

    // TODO: this copy will not be necessary in C++20.
    int64_t position = i;
    // Builds the silenceable error for a failed comparison at `position`.
    auto reportError = [&](StringRef direction) {
          emitSilenceableError() << "expected parameter to be " << direction
                                 << " " << signedAPIntAsString(refValue)
                                 << ", got " << signedAPIntAsString(value);
      diag.attachNote(getParam().getLoc())
          << "value # " << position
          << " associated with the parameter defined here";
      return diag;
    };

    switch (getPredicate()) {
    case MatchCmpIPredicate::eq:
      if (value.eq(refValue))
        break;
      return reportError("equal to");
    case MatchCmpIPredicate::ne:
      if (value.ne(refValue))
        break;
      return reportError("not equal to");
    case MatchCmpIPredicate::lt:
      if (value.slt(refValue))
        break;
      return reportError("less than");
    case MatchCmpIPredicate::le:
      if (value.sle(refValue))
        break;
      return reportError("less than or equal to");
    case MatchCmpIPredicate::gt:
      if (value.sgt(refValue))
        break;
      return reportError("greater than");
    case MatchCmpIPredicate::ge:
      if (value.sge(refValue))
        break;
      return reportError("greater than or equal to");
    }
  }
}
2263 
2264 void transform::MatchParamCmpIOp::getEffects(
2265  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2266  onlyReadsHandle(getParamMutable(), effects);
2267  onlyReadsHandle(getReferenceMutable(), effects);
2268 }
2269 
2270 //===----------------------------------------------------------------------===//
2271 // ParamConstantOp
2272 //===----------------------------------------------------------------------===//
2273 
transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  // The resulting param handle is associated with exactly one value: the
  // op's `value` attribute.
  results.setParams(cast<OpResult>(getParam()), {getValue()});
}
2281 
2282 //===----------------------------------------------------------------------===//
2283 // MergeHandlesOp
2284 //===----------------------------------------------------------------------===//
2285 
transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
                                 transform::TransformResults &results,
                                 transform::TransformState &state) {
  // Concatenates the payloads of all `handles` in operand order. With
  // `deduplicate`, repeated entries are dropped and first occurrences are
  // kept (SetVector preserves insertion order). The kind of payload (ops,
  // params or values) is chosen from the type of the first handle.
  // NOTE(review): relies on `handles` being non-empty and homogeneous in kind;
  // presumably enforced by the op verifier — confirm.
  ValueRange handles = getHandles();
  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
    SmallVector<Operation *> operations;
    for (Value operand : handles)
      llvm::append_range(operations, state.getPayloadOps(operand));
    if (!getDeduplicate()) {
      results.set(llvm::cast<OpResult>(getResult()), operations);
    }

    SetVector<Operation *> uniqued(llvm::from_range, operations);
    results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
  }

  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
    SmallVector<Attribute> attrs;
    for (Value attribute : handles)
      llvm::append_range(attrs, state.getParams(attribute));
    if (!getDeduplicate()) {
      results.setParams(cast<OpResult>(getResult()), attrs);
    }

    SetVector<Attribute> uniqued(llvm::from_range, attrs);
    results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
  }

  // Remaining case: value handles.
  assert(
      llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
      "expected value handle type");
  SmallVector<Value> payloadValues;
  for (Value value : handles)
    llvm::append_range(payloadValues, state.getPayloadValues(value));
  if (!getDeduplicate()) {
    results.setValues(cast<OpResult>(getResult()), payloadValues);
  }

  SetVector<Value> uniqued(llvm::from_range, payloadValues);
  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
}
2334 
2335 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2336  // Handles may be the same if deduplicating is enabled.
2337  return getDeduplicate();
2338 }
2339 
2340 void transform::MergeHandlesOp::getEffects(
2341  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2342  onlyReadsHandle(getHandlesMutable(), effects);
2343  producesHandle(getOperation()->getOpResults(), effects);
2344 
2345  // There are no effects on the Payload IR as this is only a handle
2346  // manipulation.
2347 }
2348 
2349 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2350  if (getDeduplicate() || getHandles().size() != 1)
2351  return {};
2352 
2353  // If deduplication is not required and there is only one operand, it can be
2354  // used directly instead of merging.
2355  return getHandles().front();
2356 }
2357 
2358 //===----------------------------------------------------------------------===//
2359 // NamedSequenceOp
2360 //===----------------------------------------------------------------------===//
2361 
transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  // Applies the body of a named sequence used as the entry point. External
  // declarations cannot be applied. Silenceable failures inside the body are
  // always propagated.
  if (isExternal())
    return emitDefiniteFailure() << "unresolved external named sequence";

  // Map the entry block argument to the list of operations.
  // Note: this is the same implementation as PossibleTopLevelTransformOp but
  // without attaching the interface / trait since that is tailored to a
  // dangling top-level op that does not get "called".
  auto scope = state.make_region_scope(getBody());
          state, this->getOperation(), getBody())))

  return applySequenceBlock(getBody().front(),
                            FailurePropagationMode::Propagate, state, results);
}
2381 
2382 void transform::NamedSequenceOp::getEffects(
2383  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2384 
                                             OperationState &result) {
  // Parsed with the generic function-like op syntax. Variadic signatures are
  // rejected, and the function type is built directly from the parsed
  // argument and result types.
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name),
      [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
         std::string &) { return builder.getFunctionType(inputs, results); },
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
2395 
      // Printed in the function-like syntax, non-variadic and with the same
      // function-type / arg-attrs / res-attrs attribute names used by parse.
      printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
      getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
      getResAttrsAttrName());
}
2402 
/// Verifies that a symbol function-like transform dialect operation is not
/// nested inside another transform op, that its entry block is non-empty and
/// terminated by a `transform.yield`, and that the yield operands match the
/// op's result types exactly (same count and identical types). If
/// `allowExternal` is set, allow external symbols (declarations) and don't
/// check the terminator as it may not exist.
verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
        << "cannot be defined inside another transform op";
    diag.attachNote(parent.getLoc()) << "ancestor transform op";
    return diag;
  }

  // Declarations (no body) are accepted only when `allowExternal` is set.
  if (op.isExternal() || op.getFunctionBody().empty()) {
    if (allowExternal)

    return emitSilenceableFailure(op) << "cannot be external";
  }

  if (op.getFunctionBody().front().empty())
    return emitSilenceableFailure(op) << "expected a non-empty body block";

  Operation *terminator = &op.getFunctionBody().front().back();
  if (!isa<transform::YieldOp>(terminator)) {
                                         << "expected '"
                                         << transform::YieldOp::getOperationName()
                                         << "' as terminator";
    diag.attachNote(terminator->getLoc()) << "terminator";
    return diag;
  }

  if (terminator->getNumOperands() != op.getResultTypes().size()) {
    return emitSilenceableFailure(terminator)
           << "expected terminator to have as many operands as the parent op "
              "has results";
  }
  // Types must be identical, not merely implement the same interface.
  for (auto [i, operandType, resultType] : llvm::zip_equal(
           llvm::seq<unsigned>(0, terminator->getNumOperands()),
           terminator->getOperands().getType(), op.getResultTypes())) {
    if (operandType == resultType)
      continue;
    return emitSilenceableFailure(terminator)
           << "the type of the terminator operand #" << i
           << " must match the type of the corresponding parent op result ("
           << operandType << " vs " << resultType << ")";
  }

}
2456 
/// Verification of a NamedSequenceOp. This does not report the error
/// immediately, so it can be used to check for op's well-formedness before the
/// verifier runs, e.g., during trait verification.
///
/// Checks, in order:
///   1. the enclosing symbol table carries the with-named-sequence attribute;
///   2. the op is not nested in another transform op;
///   3. the consume/readonly annotations (the only check for external ops);
///   4. the body's terminator shape and types;
///   5. the shared yielding-single-block checks.
/// NOTE(review): steps 2 and 4 repeat checks that verifyYieldingSingleBlockOp
/// (called last) also performs; one copy could likely be dropped, but the
/// diagnostic order would need to be preserved.
verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
    if (!parent->getAttr(
            transform::TransformDialect::kWithNamedSequenceAttrName)) {
          << "expects the parent symbol table to have the '"
          << transform::TransformDialect::kWithNamedSequenceAttrName
          << "' attribute";
      diag.attachNote(parent->getLoc()) << "symbol table operation";
      return diag;
    }
  }

  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
        << "cannot be defined inside another transform op";
    diag.attachNote(parent.getLoc()) << "ancestor transform op";
    return diag;
  }

  // Declarations have no body to inspect; only their annotations are checked.
  if (op.isExternal() || op.getBody().empty())
    return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
                                                emitWarnings);

  if (op.getBody().front().empty())
    return emitSilenceableFailure(op) << "expected a non-empty body block";

  Operation *terminator = &op.getBody().front().back();
  if (!isa<transform::YieldOp>(terminator)) {
                                         << "expected '"
                                         << transform::YieldOp::getOperationName()
                                         << "' as terminator";
    diag.attachNote(terminator->getLoc()) << "terminator";
    return diag;
  }

  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
    return emitSilenceableFailure(terminator)
           << "expected terminator to have as many operands as the parent op "
              "has results";
  }
  for (auto [i, operandType, resultType] :
       llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
                       terminator->getOperands().getType(),
                       op.getFunctionType().getResults())) {
    if (operandType == resultType)
      continue;
    return emitSilenceableFailure(terminator)
           << "the type of the terminator operand #" << i
           << " must match the type of the corresponding parent op result ("
           << operandType << " vs " << resultType << ")";
  }

  auto funcOp = cast<FunctionOpInterface>(*op);
      verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
  if (!diag.succeeded())
    return diag;

  return verifyYieldingSingleBlockOp(funcOp,
                                     /*allowExternal=*/true);
}
2526 
2527 LogicalResult transform::NamedSequenceOp::verify() {
2528  // Actual verification happens in a separate function for reusability.
2529  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2530 }
2531 
2532 template <typename FnTy>
2533 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2534  Type bbArgType, TypeRange extraBindingTypes,
2535  FnTy bodyBuilder) {
2536  SmallVector<Type> types;
2537  types.reserve(1 + extraBindingTypes.size());
2538  types.push_back(bbArgType);
2539  llvm::append_range(types, extraBindingTypes);
2540 
2541  OpBuilder::InsertionGuard guard(builder);
2542  Region *region = state.regions.back().get();
2543  Block *bodyBlock =
2544  builder.createBlock(region, region->begin(), types,
2545  SmallVector<Location>(types.size(), state.location));
2546 
2547  // Populate body.
2548  builder.setInsertionPointToStart(bodyBlock);
2549  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2550  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2551  } else {
2552  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2553  bodyBlock->getArguments().drop_front());
2554  }
2555 }
2556 
2557 void transform::NamedSequenceOp::build(OpBuilder &builder,
2558  OperationState &state, StringRef symName,
2559  Type rootType, TypeRange resultTypes,
2560  SequenceBodyBuilderFn bodyBuilder,
2562  ArrayRef<DictionaryAttr> argAttrs) {
2563  state.addAttribute(SymbolTable::getSymbolAttrName(),
2564  builder.getStringAttr(symName));
2565  state.addAttribute(getFunctionTypeAttrName(state.name),
2567  rootType, resultTypes)));
2568  state.attributes.append(attrs.begin(), attrs.end());
2569  state.addRegion();
2570 
2571  buildSequenceBody(builder, state, rootType,
2572  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2573 }
2574 
2575 //===----------------------------------------------------------------------===//
2576 // NumAssociationsOp
2577 //===----------------------------------------------------------------------===//
2578 
2580 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2581  transform::TransformResults &results,
2582  transform::TransformState &state) {
2583  size_t numAssociations =
2585  .Case([&](TransformHandleTypeInterface opHandle) {
2586  return llvm::range_size(state.getPayloadOps(getHandle()));
2587  })
2588  .Case([&](TransformValueHandleTypeInterface valueHandle) {
2589  return llvm::range_size(state.getPayloadValues(getHandle()));
2590  })
2591  .Case([&](TransformParamTypeInterface param) {
2592  return llvm::range_size(state.getParams(getHandle()));
2593  })
2594  .DefaultUnreachable("unknown kind of transform dialect type");
2595  results.setParams(cast<OpResult>(getNum()),
2596  rewriter.getI64IntegerAttr(numAssociations));
2598 }
2599 
2600 LogicalResult transform::NumAssociationsOp::verify() {
2601  // Verify that the result type accepts an i64 attribute as payload.
2602  auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2603  return resultType
2604  .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2605  .checkAndReport();
2606 }
2607 
2608 //===----------------------------------------------------------------------===//
2609 // SelectOp
2610 //===----------------------------------------------------------------------===//
2611 
2613 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2614  transform::TransformResults &results,
2615  transform::TransformState &state) {
2616  SmallVector<Operation *> result;
2617  auto payloadOps = state.getPayloadOps(getTarget());
2618  for (Operation *op : payloadOps) {
2619  if (op->getName().getStringRef() == getOpName())
2620  result.push_back(op);
2621  }
2622  results.set(cast<OpResult>(getResult()), result);
2624 }
2625 
2626 //===----------------------------------------------------------------------===//
2627 // SplitHandleOp
2628 //===----------------------------------------------------------------------===//
2629 
2630 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2631  Value target, int64_t numResultHandles) {
2632  result.addOperands(target);
2633  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2634 }
2635 
2637 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2638  transform::TransformResults &results,
2639  transform::TransformState &state) {
2640  int64_t numPayloads =
2642  .Case<TransformHandleTypeInterface>([&](auto x) {
2643  return llvm::range_size(state.getPayloadOps(getHandle()));
2644  })
2645  .Case<TransformValueHandleTypeInterface>([&](auto x) {
2646  return llvm::range_size(state.getPayloadValues(getHandle()));
2647  })
2648  .Case<TransformParamTypeInterface>([&](auto x) {
2649  return llvm::range_size(state.getParams(getHandle()));
2650  })
2651  .DefaultUnreachable("unknown transform dialect type interface");
2652 
2653  auto produceNumOpsError = [&]() {
2654  return emitSilenceableError()
2655  << getHandle() << " expected to contain " << this->getNumResults()
2656  << " payloads but it contains " << numPayloads << " payloads";
2657  };
2658 
2659  // Fail if there are more payload ops than results and no overflow result was
2660  // specified.
2661  if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2662  return produceNumOpsError();
2663 
2664  // Fail if there are more results than payload ops. Unless:
2665  // - "fail_on_payload_too_small" is set to "false", or
2666  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2667  if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2668  (numPayloads != 0 || !getPassThroughEmptyHandle()))
2669  return produceNumOpsError();
2670 
2671  // Distribute payloads.
2672  SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2673  if (getOverflowResult())
2674  resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2675 
2676  auto container = [&]() {
2677  if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2678  return llvm::map_to_vector(
2679  state.getPayloadOps(getHandle()),
2680  [](Operation *op) -> MappedValue { return op; });
2681  }
2682  if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2683  return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2684  [](Value v) -> MappedValue { return v; });
2685  }
2686  assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2687  "unsupported kind of transform dialect type");
2688  return llvm::map_to_vector(state.getParams(getHandle()),
2689  [](Attribute a) -> MappedValue { return a; });
2690  }();
2691 
2692  for (auto &&en : llvm::enumerate(container)) {
2693  int64_t resultNum = en.index();
2694  if (resultNum >= getNumResults())
2695  resultNum = *getOverflowResult();
2696  resultHandles[resultNum].push_back(en.value());
2697  }
2698 
2699  // Set transform op results.
2700  for (auto &&it : llvm::enumerate(resultHandles))
2701  results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2702  it.value());
2703 
2705 }
2706 
2707 void transform::SplitHandleOp::getEffects(
2708  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2709  onlyReadsHandle(getHandleMutable(), effects);
2710  producesHandle(getOperation()->getOpResults(), effects);
2711  // There are no effects on the Payload IR as this is only a handle
2712  // manipulation.
2713 }
2714 
2715 LogicalResult transform::SplitHandleOp::verify() {
2716  if (getOverflowResult().has_value() &&
2717  !(*getOverflowResult() < getNumResults()))
2718  return emitOpError("overflow_result is not a valid result index");
2719 
2720  for (Type resultType : getResultTypes()) {
2721  if (implementSameTransformInterface(getHandle().getType(), resultType))
2722  continue;
2723 
2724  return emitOpError("expects result types to implement the same transform "
2725  "interface as the operand type");
2726  }
2727 
2728  return success();
2729 }
2730 
2731 //===----------------------------------------------------------------------===//
2732 // ReplicateOp
2733 //===----------------------------------------------------------------------===//
2734 
2736 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2737  transform::TransformResults &results,
2738  transform::TransformState &state) {
2739  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2740  for (const auto &en : llvm::enumerate(getHandles())) {
2741  Value handle = en.value();
2742  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2743  SmallVector<Operation *> current =
2744  llvm::to_vector(state.getPayloadOps(handle));
2745  SmallVector<Operation *> payload;
2746  payload.reserve(numRepetitions * current.size());
2747  for (unsigned i = 0; i < numRepetitions; ++i)
2748  llvm::append_range(payload, current);
2749  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2750  } else {
2751  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2752  "expected param type");
2753  ArrayRef<Attribute> current = state.getParams(handle);
2754  SmallVector<Attribute> params;
2755  params.reserve(numRepetitions * current.size());
2756  for (unsigned i = 0; i < numRepetitions; ++i)
2757  llvm::append_range(params, current);
2758  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2759  params);
2760  }
2761  }
2763 }
2764 
2765 void transform::ReplicateOp::getEffects(
2766  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2767  onlyReadsHandle(getPatternMutable(), effects);
2768  onlyReadsHandle(getHandlesMutable(), effects);
2769  producesHandle(getOperation()->getOpResults(), effects);
2770 }
2771 
2772 //===----------------------------------------------------------------------===//
2773 // SequenceOp
2774 //===----------------------------------------------------------------------===//
2775 
2777 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2778  transform::TransformResults &results,
2779  transform::TransformState &state) {
2780  // Map the entry block argument to the list of operations.
2781  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2782  if (failed(mapBlockArguments(state)))
2784 
2785  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2786  results);
2787 }
2788 
2789 static ParseResult parseSequenceOpOperands(
2790  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2791  Type &rootType,
2792  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2793  SmallVectorImpl<Type> &extraBindingTypes) {
2794  OpAsmParser::UnresolvedOperand rootOperand;
2795  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2796  if (!hasRoot.has_value()) {
2797  root = std::nullopt;
2798  return success();
2799  }
2800  if (failed(hasRoot.value()))
2801  return failure();
2802  root = rootOperand;
2803 
2804  if (succeeded(parser.parseOptionalComma())) {
2805  if (failed(parser.parseOperandList(extraBindings)))
2806  return failure();
2807  }
2808  if (failed(parser.parseColon()))
2809  return failure();
2810 
2811  // The paren is truly optional.
2812  (void)parser.parseOptionalLParen();
2813 
2814  if (failed(parser.parseType(rootType))) {
2815  return failure();
2816  }
2817 
2818  if (!extraBindings.empty()) {
2819  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2820  return failure();
2821  }
2822 
2823  if (extraBindingTypes.size() != extraBindings.size()) {
2824  return parser.emitError(parser.getNameLoc(),
2825  "expected types to be provided for all operands");
2826  }
2827 
2828  // The paren is truly optional.
2829  (void)parser.parseOptionalRParen();
2830  return success();
2831 }
2832 
2834  Value root, Type rootType,
2835  ValueRange extraBindings,
2836  TypeRange extraBindingTypes) {
2837  if (!root)
2838  return;
2839 
2840  printer << root;
2841  bool hasExtras = !extraBindings.empty();
2842  if (hasExtras) {
2843  printer << ", ";
2844  printer.printOperands(extraBindings);
2845  }
2846 
2847  printer << " : ";
2848  if (hasExtras)
2849  printer << "(";
2850 
2851  printer << rootType;
2852  if (hasExtras)
2853  printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2854 }
2855 
2856 /// Returns `true` if the given op operand may be consuming the handle value in
2857 /// the Transform IR. That is, if it may have a Free effect on it.
2859  // Conservatively assume the effect being present in absence of the interface.
2860  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2861  if (!iface)
2862  return true;
2863 
2864  return isHandleConsumed(use.get(), iface);
2865 }
2866 
/// Checks that at most one use of `value` may consume it, as judged by
/// isValueUsePotentialConsumer. On finding a second potential consumer, it
/// emits the diagnostic returned by `reportError`, extended with
/// " has more than one potential consumer", plus a note on each of the two
/// uses. It then returns failure; otherwise it returns success.
2867 LogicalResult
// NOTE(review): the line carrying the function name and first parameter was
// dropped by the page extraction; the page index lists it as
// `checkDoubleConsume(Value value, function_ref<InFlightDiagnostic()> reportError)`.
2869  function_ref<InFlightDiagnostic()> reportError) {
// First potentially-consuming use seen so far, if any.
2870  OpOperand *potentialConsumer = nullptr;
2871  for (OpOperand &use : value.getUses()) {
// Uses that cannot consume the handle are irrelevant here.
2872  if (!isValueUsePotentialConsumer(use))
2873  continue;
2874 
// Remember the first potential consumer; only a second one is an error.
2875  if (!potentialConsumer) {
2876  potentialConsumer = &use;
2877  continue;
2878  }
2879 
// Second potential consumer: report both uses and stop at the first clash.
2880  InFlightDiagnostic diag = reportError()
2881  << " has more than one potential consumer";
2882  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2883  << "used here as operand #" << potentialConsumer->getOperandNumber();
2884  diag.attachNote(use.getOwner()->getLoc())
2885  << "used here as operand #" << use.getOperandNumber();
2886  return diag;
2887  }
2888 
2889  return success();
2890 }
2891 
2892 LogicalResult transform::SequenceOp::verify() {
2893  assert(getBodyBlock()->getNumArguments() >= 1 &&
2894  "the number of arguments must have been verified to be more than 1 by "
2895  "PossibleTopLevelTransformOpTrait");
2896 
2897  if (!getRoot() && !getExtraBindings().empty()) {
2898  return emitOpError()
2899  << "does not expect extra operands when used as top-level";
2900  }
2901 
2902  // Check if a block argument has more than one consuming use.
2903  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2904  if (failed(checkDoubleConsume(arg, [this, arg]() {
2905  return (emitOpError() << "block argument #" << arg.getArgNumber());
2906  }))) {
2907  return failure();
2908  }
2909  }
2910 
2911  // Check properties of the nested operations they cannot check themselves.
2912  for (Operation &child : *getBodyBlock()) {
2913  if (!isa<TransformOpInterface>(child) &&
2914  &child != &getBodyBlock()->back()) {
2916  emitOpError()
2917  << "expected children ops to implement TransformOpInterface";
2918  diag.attachNote(child.getLoc()) << "op without interface";
2919  return diag;
2920  }
2921 
2922  for (OpResult result : child.getResults()) {
2923  auto report = [&]() {
2924  return (child.emitError() << "result #" << result.getResultNumber());
2925  };
2926  if (failed(checkDoubleConsume(result, report)))
2927  return failure();
2928  }
2929  }
2930 
2931  if (!getBodyBlock()->mightHaveTerminator())
2932  return emitOpError() << "expects to have a terminator in the body";
2933 
2934  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2935  getOperation()->getResultTypes()) {
2936  InFlightDiagnostic diag = emitOpError()
2937  << "expects the types of the terminator operands "
2938  "to match the types of the result";
2939  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2940  return diag;
2941  }
2942  return success();
2943 }
2944 
2945 void transform::SequenceOp::getEffects(
2946  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2947  getPotentialTopLevelEffects(effects);
2948 }
2949 
2951 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2952  assert(point == getBody() && "unexpected region index");
2953  if (getOperation()->getNumOperands() > 0)
2954  return getOperation()->getOperands();
2955  return OperandRange(getOperation()->operand_end(),
2956  getOperation()->operand_end());
2957 }
2958 
2959 void transform::SequenceOp::getSuccessorRegions(
2960  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2961  if (point.isParent()) {
2962  Region *bodyRegion = &getBody();
2963  regions.emplace_back(bodyRegion, getNumOperands() != 0
2964  ? bodyRegion->getArguments()
2966  return;
2967  }
2968 
2969  assert(point == getBody() && "unexpected region index");
2970  regions.emplace_back(getOperation()->getResults());
2971 }
2972 
2973 void transform::SequenceOp::getRegionInvocationBounds(
2974  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2975  (void)operands;
2976  bounds.emplace_back(1, 1);
2977 }
2978 
2979 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2980  TypeRange resultTypes,
2981  FailurePropagationMode failurePropagationMode,
2982  Value root,
2983  SequenceBodyBuilderFn bodyBuilder) {
2984  build(builder, state, resultTypes, failurePropagationMode, root,
2985  /*extra_bindings=*/ValueRange());
2986  Type bbArgType = root.getType();
2987  buildSequenceBody(builder, state, bbArgType,
2988  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2989 }
2990 
2991 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2992  TypeRange resultTypes,
2993  FailurePropagationMode failurePropagationMode,
2994  Value root, ValueRange extraBindings,
2995  SequenceBodyBuilderArgsFn bodyBuilder) {
2996  build(builder, state, resultTypes, failurePropagationMode, root,
2997  extraBindings);
2998  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2999  bodyBuilder);
3000 }
3001 
3002 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3003  TypeRange resultTypes,
3004  FailurePropagationMode failurePropagationMode,
3005  Type bbArgType,
3006  SequenceBodyBuilderFn bodyBuilder) {
3007  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3008  /*extra_bindings=*/ValueRange());
3009  buildSequenceBody(builder, state, bbArgType,
3010  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3011 }
3012 
3013 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3014  TypeRange resultTypes,
3015  FailurePropagationMode failurePropagationMode,
3016  Type bbArgType, TypeRange extraBindingTypes,
3017  SequenceBodyBuilderArgsFn bodyBuilder) {
3018  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3019  /*extra_bindings=*/ValueRange());
3020  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3021 }
3022 
3023 //===----------------------------------------------------------------------===//
3024 // PrintOp
3025 //===----------------------------------------------------------------------===//
3026 
3027 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3028  StringRef name) {
3029  if (!name.empty())
3030  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3031 }
3032 
3033 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3034  Value target, StringRef name) {
3035  result.addOperands({target});
3036  build(builder, result, name);
3037 }
3038 
3040 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3041  transform::TransformResults &results,
3042  transform::TransformState &state) {
3043  llvm::outs() << "[[[ IR printer: ";
3044  if (getName().has_value())
3045  llvm::outs() << *getName() << " ";
3046 
3047  OpPrintingFlags printFlags;
3048  if (getAssumeVerified().value_or(false))
3049  printFlags.assumeVerified();
3050  if (getUseLocalScope().value_or(false))
3051  printFlags.useLocalScope();
3052  if (getSkipRegions().value_or(false))
3053  printFlags.skipRegions();
3054 
3055  if (!getTarget()) {
3056  llvm::outs() << "top-level ]]]\n";
3057  state.getTopLevel()->print(llvm::outs(), printFlags);
3058  llvm::outs() << "\n";
3059  llvm::outs().flush();
3061  }
3062 
3063  llvm::outs() << "]]]\n";
3064  for (Operation *target : state.getPayloadOps(getTarget())) {
3065  target->print(llvm::outs(), printFlags);
3066  llvm::outs() << "\n";
3067  }
3068 
3069  llvm::outs().flush();
3071 }
3072 
3073 void transform::PrintOp::getEffects(
3074  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3075  // We don't really care about mutability here, but `getTarget` now
3076  // unconditionally casts to a specific type before verification could run
3077  // here.
3078  if (!getTargetMutable().empty())
3079  onlyReadsHandle(getTargetMutable()[0], effects);
3080  onlyReadsPayload(effects);
3081 
3082  // There is no resource for stderr file descriptor, so just declare print
3083  // writes into the default resource.
3084  effects.emplace_back(MemoryEffects::Write::get());
3085 }
3086 
3087 //===----------------------------------------------------------------------===//
3088 // VerifyOp
3089 //===----------------------------------------------------------------------===//
3090 
// Runs the MLIR verifier (`::mlir::verify`) on one target payload op. If
// verification fails, returns the diagnostic "failed to verify payload op"
// with a note pointing at the payload op.
// NOTE(review): several lines were dropped by the page extraction: the return
// type line, one parameter between `target` and `state`, the line declaring
// `diag`, and the success return before the closing brace. Presumably they
// follow the applyToOne conventions used by other ops in this file; confirm
// against the real source before editing.
3092 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3093  Operation *target,
3095  transform::TransformState &state) {
3096  if (failed(::mlir::verify(target))) {
3098  << "failed to verify payload op";
3099  diag.attachNote(target->getLoc()) << "payload op";
3100  return diag;
3101  }
3103 }
3104 
3105 void transform::VerifyOp::getEffects(
3106  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3107  transform::onlyReadsHandle(getTargetMutable(), effects);
3108 }
3109 
3110 //===----------------------------------------------------------------------===//
3111 // YieldOp
3112 //===----------------------------------------------------------------------===//
3113 
3114 void transform::YieldOp::getEffects(
3115  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3116  onlyReadsHandle(getOperandsMutable(), effects);
3117 }
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a #dlti.dl_entry attribute.
Definition: DLTI.cpp:38
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue >> blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:77
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:112
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:262
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:266
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:296
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:288
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:282
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getType() const
Definition: ValueRange.cpp:32
type_range getTypes() const
Definition: ValueRange.cpp:28
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:719
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:232
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:58
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
static void printOptionValue(raw_ostream &os, const bool &value)
Utility methods for printing option values.
Definition: PassOptions.h:60
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:378
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.