MLIR  22.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/IR/Verifier.h"
28 #include "mlir/Pass/PassManager.h"
29 #include "mlir/Pass/PassRegistry.h"
30 #include "mlir/Transforms/CSE.h"
34 #include "llvm/ADT/DenseSet.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/SmallPtrSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Support/DebugLog.h"
41 #include "llvm/Support/ErrorHandling.h"
42 #include "llvm/Support/InterleavedRange.h"
43 #include <optional>
44 
45 #define DEBUG_TYPE "transform-dialect"
46 #define DEBUG_TYPE_MATCHER "transform-matcher"
47 
48 using namespace mlir;
49 
// Forward declarations of file-local custom parse/print helpers. They are
// declared ahead of the generated op definitions included further down
// (TransformOps.cpp.inc), which presumably reference them from custom
// assembly-format directives; at least parseApplyRegisteredPassOptions is
// defined later in this file.
50 static ParseResult parseApplyRegisteredPassOptions(
51  OpAsmParser &parser, DictionaryAttr &options,
52  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
// NOTE(review): this listing omits original line 53, which appears to hold the
// head of the printer declaration that the next three lines complete.
54  Operation *op,
55  DictionaryAttr options,
56  ValueRange dynamicOptions);
57 static ParseResult parseSequenceOpOperands(
58  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
59  Type &rootType,
60  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
61  SmallVectorImpl<Type> &extraBindingTypes);
62 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
63  Value root, Type rootType,
64  ValueRange extraBindings,
65  TypeRange extraBindingTypes);
66 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
67  ArrayAttr matchers, ArrayAttr actions);
68 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
69  ArrayAttr &matchers,
70  ArrayAttr &actions);
71 
72 /// Helper function to check if the given transform op is contained in (or
73 /// equal to) the given payload target op. In that case, an error is returned.
74 /// Transforming transform IR that is currently executing is generally unsafe.
///
/// Walks from `transform`'s own operation up through each parent op. If
/// `payload` is found on that chain, a definite failure is emitted on the
/// transform, with a note attached at the payload op's location.
///
/// NOTE(review): this listing omits original lines 75 (start of the
/// signature, presumably the return type), 81 (start of the `diag`
/// declaration that line 82 initializes) and 89 (presumably the success
/// return when no ancestor matches) -- confirm against upstream.
76 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
77  Operation *payload) {
78  Operation *transformAncestor = transform.getOperation();
// Climb the parent chain until a null parent op is reached.
79  while (transformAncestor) {
80  if (transformAncestor == payload) {
82  transform.emitDefiniteFailure()
83  << "cannot apply transform to itself (or one of its ancestors)";
84  diag.attachNote(payload->getLoc()) << "target payload op";
85  return diag;
86  }
87  transformAncestor = transformAncestor->getParentOp();
88  }
90 }
91 
92 #define GET_OP_CLASSES
93 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
94 
95 //===----------------------------------------------------------------------===//
96 // AlternativesOp
97 //===----------------------------------------------------------------------===//
98 
/// Operands forwarded on entry into an alternative region. When `point` is
/// not the parent op and the op has exactly one operand (presumably the
/// optional scope handle read via getScope() in apply()), that operand is
/// forwarded; otherwise an empty range anchored at the end of the operand list
/// is returned.
/// NOTE(review): original line 99 (presumably the `OperandRange` return type)
/// is missing from this listing.
100 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
101  if (!point.isParent() && getOperation()->getNumOperands() == 1)
102  return getOperation()->getOperands();
103  return OperandRange(getOperation()->operand_end(),
104  getOperation()->operand_end());
105 }
106 
/// Successor regions for control flow within the op.
/// - Entering from the parent op, every alternative is a possible successor.
/// - Leaving alternative #i, only the later alternatives (#i+1 onwards)
///   remain possible, and control may also return to the parent op. In that
///   case the op's results are the successor inputs.
/// An alternative's block arguments are used as its successor inputs when the
/// op has operands.
/// NOTE(review): original line 115, which supplies the other branch of that
/// conditional, is missing from this listing.
107 void transform::AlternativesOp::getSuccessorRegions(
108  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
109  for (Region &alternative : llvm::drop_begin(
110  getAlternatives(),
111  point.isParent() ? 0
112  : point.getRegionOrNull()->getRegionNumber() + 1)) {
113  regions.emplace_back(&alternative, !getOperands().empty()
114  ? alternative.getArguments()
116  }
117  if (!point.isParent())
118  regions.emplace_back(getOperation()->getResults());
119 }
120 
121 void transform::AlternativesOp::getRegionInvocationBounds(
122  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
123  (void)operands;
124  // The region corresponding to the first alternative is always executed, the
125  // remaining may or may not be executed.
126  bounds.reserve(getNumRegions());
127  bounds.emplace_back(1, 1);
128  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
129 }
130 
// NOTE(review): original line 131 (the beginning of this function's
// signature, including its name and the `block` parameter used below) is
// missing from this listing.
/// Maps every result of `block`'s parent op to an empty list of payload ops.
132  transform::TransformResults &results) {
133  for (const auto &res : block->getParentOp()->getOpResults())
134  results.set(res, {});
135 }
136 
/// Tries the alternative regions in order until one of them succeeds.
///
/// The scope is the payload of the optional scope handle, or the top-level
/// payload op when no handle is given. Each scope op must not contain this
/// transform and must be isolated from above; otherwise a definite failure is
/// emitted.
///
/// For every alternative:
/// - the scope ops are cloned and the region's first block argument is
///   mapped to the clones;
/// - the region's ops, except the terminator, are applied in order;
/// - a silenceable failure abandons the alternative, and its clones are
///   erased on scope exit.
///
/// For the first alternative whose transforms all succeed, each original
/// scope op is replaced by its clone (through a TrackingListener) and the
/// region's terminator operands are forwarded as the op's results. If every
/// alternative fails, a silenceable "all alternatives failed" error is
/// returned.
///
/// NOTE(review): this listing omits original lines:
/// - 137 (presumably the return type);
/// - 139 (a parameter line);
/// - 175 and 188 (the statements guarded by the `if`s on the lines before
///   them; 188 presumably propagates non-silenceable failures);
/// - 179 (start of the declaration of `result`);
/// - 207 (presumably the success return).
138 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
140  transform::TransformState &state) {
141  SmallVector<Operation *> originals;
142  if (Value scopeHandle = getScope())
143  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
144  else
145  originals.push_back(state.getTopLevel());
146
// Reject scopes that would let the transforms rewrite themselves, and scopes
// that could see IR defined outside of them.
147  for (Operation *original : originals) {
148  if (original->isAncestor(getOperation())) {
149  auto diag = emitDefiniteFailure()
150  << "scope must not contain the transforms being applied";
151  diag.attachNote(original->getLoc()) << "scope";
152  return diag;
153  }
154  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
155  auto diag = emitDefiniteFailure()
156  << "only isolated-from-above ops can be alternative scopes";
157  diag.attachNote(original->getLoc()) << "scope";
158  return diag;
159  }
160  }
161
162  for (Region &reg : getAlternatives()) {
163  // Clone the scope operations and make the transforms in this alternative
164  // region apply to them by virtue of mapping the block argument (the only
165  // visible handle) to the cloned scope operations. This effectively prevents
166  // the transformation from accessing any IR outside the scope.
167  auto scope = state.make_region_scope(reg);
168  auto clones = llvm::to_vector(
169  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
170  auto deleteClones = llvm::make_scope_exit([&] {
171  for (Operation *clone : clones)
172  clone->erase();
173  });
174  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
176
// NOTE(review): this local `failed` shadows mlir::failed(), which is why the
// check further below is spelled `::mlir::failed(...)`.
177  bool failed = false;
178  for (Operation &transform : reg.front().without_terminator()) {
180  state.applyTransform(cast<TransformOpInterface>(transform));
181  if (result.isSilenceableFailure()) {
182  LDBG() << "alternative failed: " << result.getMessage();
183  failed = true;
184  break;
185  }
186
187  if (::mlir::failed(result.silence()))
189  }
190
191  // If all operations in the given alternative succeeded, no need to consider
192  // the rest. Replace the original scoping operation with the clone on which
193  // the transformations were performed.
194  if (!failed) {
195  // We will be using the clones, so cancel their scheduled deletion.
196  deleteClones.release();
197  TrackingListener listener(state, *this);
// NOTE(review): this local `rewriter` shadows the `rewriter` parameter of
// apply(); the replacement below is observed by the tracking listener.
198  IRRewriter rewriter(getContext(), &listener);
199  for (const auto &kvp : llvm::zip(originals, clones)) {
200  Operation *original = std::get<0>(kvp);
201  Operation *clone = std::get<1>(kvp);
202  original->getBlock()->getOperations().insert(original->getIterator(),
203  clone);
204  rewriter.replaceOp(original, clone->getResults());
205  }
206  detail::forwardTerminatorOperands(&reg.front(), state, results);
208  }
209  }
210  return emitSilenceableError() << "all alternatives failed";
211 }
212 
213 void transform::AlternativesOp::getEffects(
214  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
215  consumesHandle(getOperation()->getOpOperands(), effects);
216  producesHandle(getOperation()->getOpResults(), effects);
217  for (Region *region : getRegions()) {
218  if (!region->empty())
219  producesHandle(region->front().getArguments(), effects);
220  }
221  modifiesPayload(effects);
222 }
223 
224 LogicalResult transform::AlternativesOp::verify() {
225  for (Region &alternative : getAlternatives()) {
226  Block &block = alternative.front();
227  Operation *terminator = block.getTerminator();
228  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
229  InFlightDiagnostic diag = emitOpError()
230  << "expects terminator operands to have the "
231  "same type as results of the operation";
232  diag.attachNote(terminator->getLoc()) << "terminator";
233  return diag;
234  }
235  }
236 
237  return success();
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // AnnotateOp
242 //===----------------------------------------------------------------------===//
243 
/// Sets the attribute named getName() on every payload op of the target
/// handle.
/// - With a param handle holding exactly one attribute, that attribute is
///   set on all targets.
/// - With any other count, the param list must be as long as the target list
///   and attributes are assigned pairwise. A length mismatch is a silenceable
///   error.
/// - Without a param handle, the value of `attr` declared on the omitted line
///   251 is used.
///
/// NOTE(review): this listing omits original lines:
/// - 244 (presumably the return type);
/// - 246 (a parameter line);
/// - 251 (declaration of `attr`, presumably a default value such as a unit
///   attribute -- confirm);
/// - 262 (presumably the success return after the pairwise loop);
/// - 268 (presumably the final return).
245 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
247  transform::TransformState &state) {
248  SmallVector<Operation *> targets =
249  llvm::to_vector(state.getPayloadOps(getTarget()));
250
252  if (auto paramH = getParam()) {
253  ArrayRef<Attribute> params = state.getParams(paramH);
// A single param is broadcast to all targets; otherwise pair them up.
254  if (params.size() != 1) {
255  if (targets.size() != params.size()) {
256  return emitSilenceableError()
257  << "parameter and target have different payload lengths ("
258  << params.size() << " vs " << targets.size() << ")";
259  }
260  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
261  target->setAttr(getName(), attr);
263  }
264  attr = params[0];
265  }
266  for (auto *target : targets)
267  target->setAttr(getName(), attr);
269 }
270 
271 void transform::AnnotateOp::getEffects(
272  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
273  onlyReadsHandle(getTargetMutable(), effects);
274  onlyReadsHandle(getParamMutable(), effects);
275  modifiesPayload(effects);
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // ApplyCommonSubexpressionEliminationOp
280 //===----------------------------------------------------------------------===//
281 
/// Runs common subexpression elimination on `target` through the transform
/// rewriter. Beforehand, it checks that `target` neither is nor contains this
/// transform op. Dominance information is computed fresh for the call.
///
/// NOTE(review): this listing omits original lines:
/// - 282 (presumably the return type);
/// - 289 (the initializer of `payloadCheck`, presumably a call to
///   ensurePayloadIsSeparateFromTransform);
/// - 295 (presumably the success return).
/// `results` and `state` are not used in the visible lines.
283 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
284  transform::TransformRewriter &rewriter, Operation *target,
285  ApplyToEachResultList &results, transform::TransformState &state) {
286  // Make sure that this transform is not applied to itself. Modifying the
287  // transform IR while it is being interpreted is generally dangerous.
288  DiagnosedSilenceableFailure payloadCheck =
290  if (!payloadCheck.succeeded())
291  return payloadCheck;
292
293  DominanceInfo domInfo;
294  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
296 }
297 
/// Side effects: the target handle is only read, not consumed.
/// NOTE(review): original line 301 is missing from this listing; presumably it
/// declares the payload modification -- confirm.
298 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
299  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
300  transform::onlyReadsHandle(getTargetMutable(), effects);
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // ApplyDeadCodeEliminationOp
306 //===----------------------------------------------------------------------===//
307 
/// Erases trivially dead ops nested inside `target` (never `target` itself)
/// until no more ops become dead.
///
/// A post-order walk first erases every op that is already trivially dead.
/// Before each erasure, it queues the defining ops (located inside `target`)
/// of the operands used by that op and its nested ops. The worklist is then
/// drained: each queued op that is now trivially dead is erased the same way,
/// which may queue further ops.
///
/// NOTE(review): this listing omits original lines 314 (the initializer of
/// `payloadCheck`) and 360 (presumably the success return).
308 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
309  transform::TransformRewriter &rewriter, Operation *target,
310  ApplyToEachResultList &results, transform::TransformState &state) {
311  // Make sure that this transform is not applied to itself. Modifying the
312  // transform IR while it is being interpreted is generally dangerous.
313  DiagnosedSilenceableFailure payloadCheck =
315  if (!payloadCheck.succeeded())
316  return payloadCheck;
317
318  // Maintain a worklist of potentially dead ops.
319  SetVector<Operation *> worklist;
320
321  // Helper function that adds all defining ops of used values (operands and
322  // operands of nested ops).
323  auto addDefiningOpsToWorklist = [&](Operation *op) {
324  op->walk([&](Operation *op) {
325  for (Value v : op->getOperands())
326  if (Operation *defOp = v.getDefiningOp())
327  if (target->isProperAncestor(defOp))
328  worklist.insert(defOp);
329  });
330  };
331
332  // Helper function that erases an op.
333  auto eraseOp = [&](Operation *op) {
334  // Remove op and nested ops from the worklist.
335  op->walk([&](Operation *op) {
// NOTE(review): llvm::find scans the worklist linearly; using SetVector's own
// membership lookup (e.g. remove()) would avoid that search.
336  const auto *it = llvm::find(worklist, op);
337  if (it != worklist.end())
338  worklist.erase(it);
339  });
340  rewriter.eraseOp(op);
341  };
342
343  // Initial walk over the IR.
344  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
345  if (op != target && isOpTriviallyDead(op)) {
346  addDefiningOpsToWorklist(op);
347  eraseOp(op);
348  }
349  });
350
351  // Erase all ops that have become dead.
352  while (!worklist.empty()) {
353  Operation *op = worklist.pop_back_val();
354  if (!isOpTriviallyDead(op))
355  continue;
356  addDefiningOpsToWorklist(op);
357  eraseOp(op);
358  }
359
361 }
362 
/// Side effects: the target handle is only read, not consumed.
/// NOTE(review): original line 366 is missing from this listing; presumably it
/// declares the payload modification -- confirm.
363 void transform::ApplyDeadCodeEliminationOp::getEffects(
364  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
365  transform::onlyReadsHandle(getTargetMutable(), effects);
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // ApplyPatternsOp
371 //===----------------------------------------------------------------------===//
372 
/// Greedily applies the patterns collected from the op's region to `target`,
/// optionally interleaved with CSE. Each child op in the region must
/// implement PatternDescriptorOpInterface.
///
/// - Isolated-from-above targets are rewritten with applyPatternsGreedily.
/// - Otherwise, all ops nested in `target` (excluding `target` itself) are
///   collected and passed to applyOpPatternsGreedily.
/// - A failed greedy application yields a silenceable failure.
/// - When CSE is requested and changes the IR, the pattern+CSE round is
///   repeated, up to kNumMaxIterations (50) rounds. Reaching that cap is
///   reported as a definite failure.
///
/// NOTE(review): this listing omits original lines:
/// - 381 (initializer of `payloadCheck`);
/// - 387 (presumably the declaration of `patterns`, which would use the
///   otherwise unused `ctx`);
/// - 396 (declaration of `config`);
/// - 402 and 405 (the value used when the corresponding limit equals
///   `static_cast<uint64_t>(-1)`);
/// - 428 (declaration of `ops`);
/// - 453 (presumably the success return).
373 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
374  transform::TransformRewriter &rewriter, Operation *target,
375  ApplyToEachResultList &results, transform::TransformState &state) {
376  // Make sure that this transform is not applied to itself. Modifying the
377  // transform IR while it is being interpreted is generally dangerous. Even
378  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
379  // performs many additional simplifications such as dead code elimination.
380  DiagnosedSilenceableFailure payloadCheck =
382  if (!payloadCheck.succeeded())
383  return payloadCheck;
384
385  // Gather all specified patterns.
386  MLIRContext *ctx = target->getContext();
388  if (!getRegion().empty()) {
389  for (Operation &op : getRegion().front()) {
390  cast<transform::PatternDescriptorOpInterface>(&op)
391  .populatePatternsWithState(patterns, state);
392  }
393  }
394
395  // Configure the GreedyPatternRewriteDriver.
397  config.setListener(
398  static_cast<RewriterBase::Listener *>(rewriter.getListener()));
399  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
400
401  config.setMaxIterations(getMaxIterations() == static_cast<uint64_t>(-1)
403  : getMaxIterations());
404  config.setMaxNumRewrites(getMaxNumRewrites() == static_cast<uint64_t>(-1)
406  : getMaxNumRewrites());
407
408  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
409  // was requested, apply the greedy pattern rewrite only once. (The greedy
410  // pattern rewrite driver already iterates to a fixpoint internally.)
411  bool cseChanged = false;
412  // One or two iterations should be sufficient. Stop iterating after a certain
413  // threshold to make debugging easier.
414  static const int64_t kNumMaxIterations = 50;
415  int64_t iteration = 0;
416  do {
417  LogicalResult result = failure();
418  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
419  // Op is isolated from above. Apply patterns and also perform region
420  // simplification.
421  result = applyPatternsGreedily(target, frozenPatterns, config);
422  } else {
423  // Manually gather list of ops because the other
424  // GreedyPatternRewriteDriver overloads only accepts ops that are isolated
425  // from above. This way, patterns can be applied to ops that are not
426  // isolated from above. Regions are not being simplified. Furthermore,
427  // only a single greedy rewrite iteration is performed.
429  target->walk([&](Operation *nestedOp) {
430  if (target != nestedOp)
431  ops.push_back(nestedOp);
432  });
433  result = applyOpPatternsGreedily(ops, frozenPatterns, config);
434  }
435
436  // A failure typically indicates that the pattern application did not
437  // converge.
438  if (failed(result)) {
439  return emitSilenceableFailure(target)
440  << "greedy pattern application failed";
441  }
442
443  if (getApplyCse()) {
444  DominanceInfo domInfo;
445  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
446  &cseChanged);
447  }
448  } while (cseChanged && ++iteration < kNumMaxIterations);
449
// `iteration` only reaches the cap if the last allowed round still changed IR.
450  if (iteration == kNumMaxIterations)
451  return emitDefiniteFailure() << "fixpoint iteration did not converge";
452
454 }
455 
456 LogicalResult transform::ApplyPatternsOp::verify() {
457  if (!getRegion().empty()) {
458  for (Operation &op : getRegion().front()) {
459  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
460  InFlightDiagnostic diag = emitOpError()
461  << "expected children ops to implement "
462  "PatternDescriptorOpInterface";
463  diag.attachNote(op.getLoc()) << "op without interface";
464  return diag;
465  }
466  }
467  }
468  return success();
469 }
470 
/// Side effects: the target handle is only read, not consumed.
/// NOTE(review): original line 474 is missing from this listing; presumably it
/// declares the payload modification -- confirm.
471 void transform::ApplyPatternsOp::getEffects(
472  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
473  transform::onlyReadsHandle(getTargetMutable(), effects);
475 }
476 
477 void transform::ApplyPatternsOp::build(
478  OpBuilder &builder, OperationState &result, Value target,
479  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
480  result.addOperands(target);
481 
482  OpBuilder::InsertionGuard g(builder);
483  Region *region = result.addRegion();
484  builder.createBlock(region);
485  if (bodyBuilder)
486  bodyBuilder(builder, result.location);
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // ApplyCanonicalizationPatternsOp
491 //===----------------------------------------------------------------------===//
492 
/// Collects canonicalization patterns: first those of every loaded dialect,
/// then, via the loop whose header is missing from this listing (original
/// line 498), those of operations known to the context -- presumably every
/// registered op, confirm.
/// NOTE(review): original line 494 (the parameter list, which declares
/// `patterns`) is also missing.
493 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
495  MLIRContext *ctx = patterns.getContext();
496  for (Dialect *dialect : ctx->getLoadedDialects())
497  dialect->getCanonicalizationPatterns(patterns);
499  op.getCanonicalizationPatterns(patterns, ctx);
500 }
501 
502 //===----------------------------------------------------------------------===//
503 // ApplyConversionPatternsOp
504 //===----------------------------------------------------------------------===//
505 
/// Runs a dialect conversion on each payload op of the target handle. The
/// conversion is partial or full according to getPartialConversion().
///
/// Setup:
/// - The ConversionTarget is built from the legal/illegal op and dialect name
///   lists.
/// - Each child of the patterns region is a
///   ConversionPatternDescriptorOpInterface op. It contributes target rules
///   and patterns, using its own type converter or, if it has none, the
///   default type converter built by getDefaultTypeConverter().
/// - A definite failure is emitted when neither type converter exists.
///
/// Per target:
/// - The target is first checked not to contain this transform.
/// - When handles are preserved, a tracking listener observes the
///   conversion.
/// - The listener's error state is checked after each target and is
///   reported either on its own or as a note on the conversion failure.
///
/// NOTE(review): this listing omits original lines:
/// - 507-508 (the rest of the signature);
/// - 537 (presumably the declaration of `patterns`);
/// - 592 (initializer of `payloadCheck`);
/// - 606 (declaration of `diag`, presumably initialized to success);
/// - 630 (presumably the success return).
506 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
509  MLIRContext *ctx = getContext();
510
511  // Instantiate the default type converter if a type converter builder is
512  // specified.
513  std::unique_ptr<TypeConverter> defaultTypeConverter;
514  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
515  getDefaultTypeConverter();
516  if (typeConverterBuilder)
517  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
518
519  // Configure conversion target.
520  ConversionTarget conversionTarget(*getContext());
521  if (getLegalOps())
522  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
523  conversionTarget.addLegalOp(
524  OperationName(cast<StringAttr>(attr).getValue(), ctx));
525  if (getIllegalOps())
526  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
527  conversionTarget.addIllegalOp(
528  OperationName(cast<StringAttr>(attr).getValue(), ctx));
529  if (getLegalDialects())
530  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
531  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
532  if (getIllegalDialects())
533  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
534  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
535
536  // Gather all specified patterns.
538  // Need to keep the converters alive until after pattern application because
539  // the patterns take a reference to an object that would otherwise get out of
540  // scope.
541  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
542  if (!getPatterns().empty()) {
543  for (Operation &op : getPatterns().front()) {
544  auto descriptor =
545  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
546
547  // Check if this pattern set specifies a type converter.
548  std::unique_ptr<TypeConverter> typeConverter =
549  descriptor.getTypeConverter();
550  TypeConverter *converter = nullptr;
551  if (typeConverter) {
552  keepAliveConverters.emplace_back(std::move(typeConverter));
553  converter = keepAliveConverters.back().get();
554  } else {
555  // No type converter specified: Use the default type converter.
556  if (!defaultTypeConverter) {
557  auto diag = emitDefiniteFailure()
558  << "pattern descriptor does not specify type "
559  "converter and apply_conversion_patterns op has "
560  "no default type converter";
561  diag.attachNote(op.getLoc()) << "pattern descriptor op";
562  return diag;
563  }
564  converter = defaultTypeConverter.get();
565  }
566
567  // Add descriptor-specific updates to the conversion target, which may
568  // depend on the final type converter. In structural converters, the
569  // legality of types dictates the dynamic legality of an operation.
570  descriptor.populateConversionTargetRules(*converter, conversionTarget);
571
572  descriptor.populatePatterns(*converter, patterns);
573  }
574  }
575
576  // Attach a tracking listener if handles should be preserved. We configure the
577  // listener to allow op replacements with different names, as conversion
578  // patterns typically replace ops with replacement ops that have a different
579  // name.
580  TrackingListenerConfig trackingConfig;
581  trackingConfig.requireMatchingReplacementOpName = false;
582  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
583  ConversionConfig conversionConfig;
584  if (getPreserveHandles())
585  conversionConfig.listener = &trackingListener;
586
587  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
588  for (Operation *target : state.getPayloadOps(getTarget())) {
589  // Make sure that this transform is not applied to itself. Modifying the
590  // transform IR while it is being interpreted is generally dangerous.
591  DiagnosedSilenceableFailure payloadCheck =
593  if (!payloadCheck.succeeded())
594  return payloadCheck;
595
596  LogicalResult status = failure();
597  if (getPartialConversion()) {
598  status = applyPartialConversion(target, conversionTarget, frozenPatterns,
599  conversionConfig);
600  } else {
601  status = applyFullConversion(target, conversionTarget, frozenPatterns,
602  conversionConfig);
603  }
604
605  // Check dialect conversion state.
607  if (failed(status)) {
608  diag = emitSilenceableError() << "dialect conversion failed";
609  diag.attachNote(target->getLoc()) << "target op";
610  }
611
612  // Check tracking listener error state.
613  DiagnosedSilenceableFailure trackingFailure =
614  trackingListener.checkAndResetError();
615  if (!trackingFailure.succeeded()) {
616  if (diag.succeeded()) {
617  // Tracking failure is the only failure.
618  return trackingFailure;
619  } else {
620  diag.attachNote() << "tracking listener also failed: "
621  << trackingFailure.getMessage();
622  (void)trackingFailure.silence();
623  }
624  }
625
626  if (!diag.succeeded())
627  return diag;
628  }
629
631 }
632 
// NOTE(review): this listing omits original lines 633 (the function header,
// presumably ApplyConversionPatternsOp::verify) and 639 (the start of the
// `diag` declaration completed on the following line).
/// Verifies the region structure:
/// - the op has one or two regions;
/// - every op in the patterns region implements
///   ConversionPatternDescriptorOpInterface;
/// - if the second (default type converter) region exists, it holds exactly
///   one op;
/// - that op implements TypeConverterBuilderOpInterface;
/// - every pattern descriptor accepts that op via verifyTypeConverter().
634  if (getNumRegions() != 1 && getNumRegions() != 2)
635  return emitOpError() << "expected 1 or 2 regions";
636  if (!getPatterns().empty()) {
637  for (Operation &op : getPatterns().front()) {
638  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
640  emitOpError() << "expected pattern children ops to implement "
641  "ConversionPatternDescriptorOpInterface";
642  diag.attachNote(op.getLoc()) << "op without interface";
643  return diag;
644  }
645  }
646  }
647  if (getNumRegions() == 2) {
648  Region &typeConverterRegion = getRegion(1);
649  if (!llvm::hasSingleElement(typeConverterRegion.front()))
650  return emitOpError()
651  << "expected exactly one op in default type converter region";
652  Operation *maybeTypeConverter = &typeConverterRegion.front().front();
653  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
654  maybeTypeConverter);
655  if (!typeConverterOp) {
656  InFlightDiagnostic diag = emitOpError()
657  << "expected default converter child op to "
658  "implement TypeConverterBuilderOpInterface";
659  diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
660  return diag;
661  }
662  // Check default type converter type.
663  if (!getPatterns().empty()) {
664  for (Operation &op : getPatterns().front()) {
665  auto descriptor =
666  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
667  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
668  return failure();
669  }
670  }
671  }
672  return success();
673 }
674 
/// Side effects: the target handle is consumed unless handles are preserved,
/// in which case it is only read.
/// NOTE(review): original line 682 is missing from this listing; presumably it
/// declares the payload modification -- confirm.
675 void transform::ApplyConversionPatternsOp::getEffects(
676  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
677  if (!getPreserveHandles()) {
678  transform::consumesHandle(getTargetMutable(), effects);
679  } else {
680  transform::onlyReadsHandle(getTargetMutable(), effects);
681  }
683 }
684 
685 void transform::ApplyConversionPatternsOp::build(
686  OpBuilder &builder, OperationState &result, Value target,
687  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
688  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
689  result.addOperands(target);
690 
691  {
692  OpBuilder::InsertionGuard g(builder);
693  Region *region1 = result.addRegion();
694  builder.createBlock(region1);
695  if (patternsBodyBuilder)
696  patternsBodyBuilder(builder, result.location);
697  }
698  {
699  OpBuilder::InsertionGuard g(builder);
700  Region *region2 = result.addRegion();
701  builder.createBlock(region2);
702  if (typeConverterBodyBuilder)
703  typeConverterBodyBuilder(builder, result.location);
704  }
705 }
706 
707 //===----------------------------------------------------------------------===//
708 // ApplyToLLVMConversionPatternsOp
709 //===----------------------------------------------------------------------===//
710 
711 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
712  TypeConverter &typeConverter, RewritePatternSet &patterns) {
713  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
714  assert(dialect && "expected that dialect is loaded");
715  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
716  // ConversionTarget is currently ignored because the enclosing
717  // apply_conversion_patterns op sets up its own ConversionTarget.
718  ConversionTarget target(*getContext());
719  iface->populateConvertToLLVMConversionPatterns(
720  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
721 }
722 
723 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
724  transform::TypeConverterBuilderOpInterface builder) {
725  if (builder.getTypeConverterType() != "LLVMTypeConverter")
726  return emitOpError("expected LLVMTypeConverter");
727  return success();
728 }
729 
// NOTE(review): original line 730 (the function header, presumably
// ApplyToLLVMConversionPatternsOp::verify) is missing from this listing.
/// Verifies that the named dialect is loaded in the context and implements
/// ConvertToLLVMPatternInterface. Per the error text, that interface may come
/// from an extension that must also be loaded.
731  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
732  if (!dialect)
733  return emitOpError("unknown dialect or dialect not loaded: ")
734  << getDialectName();
735  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
736  if (!iface)
737  return emitOpError(
738  "dialect does not implement ConvertToLLVMPatternInterface or "
739  "extension was not loaded: ")
740  << getDialectName();
741  return success();
742 }
743 
744 //===----------------------------------------------------------------------===//
745 // ApplyLoopInvariantCodeMotionOp
746 //===----------------------------------------------------------------------===//
747 
/// Hoists loop-invariant code out of `target` by calling
/// moveLoopInvariantCode directly, without going through the transform
/// rewriter (see the comment in the body).
/// NOTE(review): unlike the other payload-modifying ops in this file, no
/// ensurePayloadIsSeparateFromTransform check is visible here.
/// NOTE(review): this listing omits original lines 748 (presumably the return
/// type), 751 (a parameter line) and 756 (presumably the success return).
749 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
750  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
752  transform::TransformState &state) {
753  // Currently, LICM does not remove operations, so we don't need tracking.
754  // If this ever changes, add a LICM entry point that takes a rewriter.
755  moveLoopInvariantCode(target);
757 }
758 
/// Side effects: the target handle is only read, not consumed.
/// NOTE(review): original line 762 is missing from this listing; presumably it
/// declares the payload modification -- confirm.
759 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
760  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
761  transform::onlyReadsHandle(getTargetMutable(), effects);
763 }
764 
765 //===----------------------------------------------------------------------===//
766 // ApplyRegisteredPassOp
767 //===----------------------------------------------------------------------===//
768 
769 void transform::ApplyRegisteredPassOp::getEffects(
770  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
771  consumesHandle(getTargetMutable(), effects);
772  onlyReadsHandle(getDynamicOptionsMutable(), effects);
773  producesHandle(getOperation()->getOpResults(), effects);
774  modifiesPayload(effects);
775 }
776 
/// Runs a registered pass or pass pipeline on each payload op of the target
/// handle, then maps the result handle to those same ops.
///
/// The options dictionary is flattened into one space-separated `key=value`
/// string:
/// - a ParamOperandAttr value is replaced by the attribute(s) of the
///   referenced dynamic-option param;
/// - arrays are expanded recursively with "," separators;
/// - strings are emitted unquoted;
/// - other attributes are printed without their type.
///
/// Pipeline names are looked up before pass names. An unknown name, or a
/// failure to add it to the pass manager, is a definite failure. A failing
/// run on a target is a silenceable failure.
///
/// NOTE(review): this listing omits original lines:
/// - 777 (presumably the return type);
/// - 779 (a parameter line);
/// - 853 (initializer of `payloadCheck`);
/// - 867 (presumably the success return).
778 transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
780  transform::TransformState &state) {
781  // Obtain a single options-string to pass to the pass(-pipeline) from options
782  // passed in as a dictionary of keys mapping to values which are either
783  // attributes or param-operands pointing to attributes.
784  OperandRange dynamicOptions = getDynamicOptions();
785
786  std::string options;
787  llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
788
789  // A helper to convert an option's attribute value into a corresponding
790  // string representation, with the ability to obtain the attr(s) from a param.
791  std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
792  if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
793  // The corresponding value attribute(s) is/are passed in via a param.
794  // Obtain the param-operand via its specified index.
795  int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
// NOTE(review): only the upper bound of the index is asserted. Also, the two
// adjacent message literals below concatenate without a space
// ("DictionaryAttrshould").
796  assert(dynamicOptionIdx < static_cast<int64_t>(dynamicOptions.size()) &&
797  "the number of ParamOperandAttrs in the options DictionaryAttr"
798  "should be the same as the number of options passed as params");
799  ArrayRef<Attribute> attrsAssociatedToParam =
800  state.getParams(dynamicOptions[dynamicOptionIdx]);
801  // Recursive so as to append all attrs associated to the param.
802  llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
803  ",");
804  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
805  // Recursive so as to append all nested attrs of the array.
806  llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
807  } else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
808  // Convert to unquoted string.
// NOTE(review): `.str()` builds a temporary std::string; streaming the
// StringRef directly would avoid that copy.
809  optionsStream << strAttr.getValue().str();
810  } else {
811  // For all other attributes, ask the attr to print itself (without type).
812  valueAttr.print(optionsStream, /*elideType=*/true);
813  }
814  };
815
816  // Convert the options DictionaryAttr into a single string.
817  llvm::interleave(
818  getOptions(), optionsStream,
819  [&](auto namedAttribute) {
820  optionsStream << namedAttribute.getName().str(); // Append the key.
821  optionsStream << "="; // And the key-value separator.
822  appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
823  },
824  " ");
825  optionsStream.flush();
826
827  // Get pass or pass pipeline from registry.
828  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
829  if (!info)
830  info = PassInfo::lookup(getPassName());
831  if (!info)
832  return emitDefiniteFailure()
833  << "unknown pass or pass pipeline: " << getPassName();
834
835  // Create pass manager and add the pass or pass pipeline.
836  PassManager pm(getContext());
837  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
838  emitError(msg);
839  return failure();
840  }))) {
841  return emitDefiniteFailure()
842  << "failed to add pass or pass pipeline to pipeline: "
843  << getPassName();
844  }
845
846  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
847  for (Operation *target : targets) {
848  // Make sure that this transform is not applied to itself. Modifying the
849  // transform IR while it is being interpreted is generally dangerous. Even
850  // more so when applying passes because they may perform a wide range of IR
851  // modifications.
852  DiagnosedSilenceableFailure payloadCheck =
854  if (!payloadCheck.succeeded())
855  return payloadCheck;
856
857  // Run the pass or pass pipeline on the current target operation.
858  if (failed(pm.run(target))) {
859  auto diag = emitSilenceableError() << "pass pipeline failed";
860  diag.attachNote(target->getLoc()) << "target op";
861  return diag;
862  }
863  }
864
865  // The applied pass will have directly modified the payload IR(s).
866  results.set(llvm::cast<OpResult>(getResult()), targets);
868 }
869 
871  OpAsmParser &parser, DictionaryAttr &options,
872  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
  // Parses the options of ApplyRegisteredPassOp written as
  // `{ key = value, ... }` into the `options` dictionary. Keys are bare
  // identifiers or strings. A value given as an SSA operand is appended to
  // `dynamicOptions` and recorded in `options` as a ParamOperandAttr holding
  // that operand's index; arrays may mix attributes and operands. Duplicate
  // keys are rejected and the resulting dictionary is sorted by key.
873  // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
874  SmallVector<NamedAttribute> keyValuePairs;
  // Number of operands parsed so far; it becomes the index stored in the
  // ParamOperandAttr of the next parsed operand.
875  size_t dynamicOptionsIdx = 0;
876 
877  // Helper for allowing parsing of option values which can be of the form:
878  // - a normal attribute
879  // - an operand (which would be converted to an attr referring to the operand)
880  // - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
  // Declared as std::function (rather than `auto`) so that the lambda can call
  // itself recursively when parsing array elements.
881  std::function<ParseResult(Attribute &)> parseValue =
882  [&](Attribute &valueAttr) -> ParseResult {
883  // Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
884  if (succeeded(parser.parseOptionalLSquare())) {
886 
887  // Recursively parse the array's elements, which might be operands.
888  if (parser.parseCommaSeparatedList(
890  [&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
891  " in options dictionary") ||
892  parser.parseRSquare())
893  return failure(); // NB: Attempted parse should've output error message.
894 
895  valueAttr = ArrayAttr::get(parser.getContext(), attrs);
896 
897  return success();
898  }
899 
900  // Parse the value, which can be either an attribute or an operand.
901  OptionalParseResult parsedValueAttr =
902  parser.parseOptionalAttribute(valueAttr);
903  if (!parsedValueAttr.has_value()) {
905  ParseResult parsedOperand = parser.parseOperand(operand);
906  if (failed(parsedOperand))
907  return failure(); // NB: Attempted parse should've output error message.
908  // To make use of the operand, we need to store it in the options dict.
909  // As SSA-values cannot occur in attributes, what we do instead is store
910  // an attribute in its place that contains the index of the param-operand,
911  // so that an attr-value associated to the param can be resolved later on.
912  dynamicOptions.push_back(operand);
913  auto wrappedIndex = IntegerAttr::get(
914  IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
915  valueAttr =
916  transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
917  } else if (failed(parsedValueAttr.value())) {
918  return failure(); // NB: Attempted parse should have output error message.
919  } else if (isa<transform::ParamOperandAttr>(valueAttr)) {
  // A literal param_operand attribute would be indistinguishable from a
  // parsed operand reference, so it is rejected in the custom syntax.
920  return parser.emitError(parser.getCurrentLocation())
921  << "the param_operand attribute is a marker reserved for "
922  << "indicating a value will be passed via params and is only used "
923  << "in the generic print format";
924  }
925 
926  return success();
927  };
928 
929  // Helper for `key = value`-pair parsing where `key` is a bare identifier or a
930  // string and `value` looks like either an attribute or an operand-in-an-attr.
931  std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
932  std::string key;
933  Attribute valueAttr;
934 
935  if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
936  return parser.emitError(parser.getCurrentLocation())
937  << "expected key to either be an identifier or a string";
938 
939  if (failed(parser.parseEqual()))
940  return parser.emitError(parser.getCurrentLocation())
941  << "expected '=' after key in key-value pair";
942 
  // NOTE(review): the failing value parser has normally emitted its own
  // error already, so this adds a second, more general diagnostic.
943  if (failed(parseValue(valueAttr)))
944  return parser.emitError(parser.getCurrentLocation())
945  << "expected a valid attribute or operand as value associated "
946  << "to key '" << key << "'";
947 
948  keyValuePairs.push_back(NamedAttribute(key, valueAttr));
949 
950  return success();
951  };
952 
955  " in options dictionary"))
956  return failure(); // NB: Attempted parse should have output error message.
957 
958  if (DictionaryAttr::findDuplicate(
959  keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
960  .has_value())
961  return parser.emitError(parser.getCurrentLocation())
962  << "duplicate keys found in options dictionary";
963 
964  options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
965 
966  return success();
967 }
968 
970  Operation *op,
971  DictionaryAttr options,
972  ValueRange dynamicOptions) {
  // Prints `options` in the `{key = value, ...}` form that
  // parseApplyRegisteredPassOptions accepts. ParamOperandAttr values are
  // printed as the SSA operand they index into `dynamicOptions`. Arrays are
  // printed element-wise so that operands nested in them are resolved too.
  // Nothing at all is printed when there are no options.
973  if (options.empty())
974  return;
975 
  // std::function so that the lambda can recurse into nested ArrayAttrs.
976  std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
977  if (auto paramOperandAttr =
978  dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
979  // Resolve index of param-operand to its actual SSA-value and print that.
  // NOTE(review): no bounds check here; presumably relies on the op
  // verifier having checked that the index is within `dynamicOptions`.
980  printer.printOperand(
981  dynamicOptions[paramOperandAttr.getIndex().getInt()]);
982  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
983  // This case is so that ArrayAttr-contained operands are pretty-printed.
984  printer << "[";
985  llvm::interleaveComma(arrayAttr, printer, printOptionValue);
986  printer << "]";
987  } else {
988  printer.printAttribute(valueAttr);
989  }
990  };
991 
992  printer << "{";
993  llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
994  printer << namedAttribute.getName();
995  printer << " = ";
996  printOptionValue(namedAttribute.getValue());
997  });
998  printer << "}";
999 }
1000 
1002  // Check that there is a one-to-one correspondence between param operands
1003  // and references to dynamic options in the options dictionary.
1004 
  // Working copy of the operands. An entry is nulled out once a
  // ParamOperandAttr refers to it. This detects repeated references (entry
  // already null) as well as unreferenced operands (entry still non-null at
  // the end).
1005  auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
1006 
1007  // Helper for option values to mark seen operands as having been seen (once).
1008  std::function<LogicalResult(Attribute)> checkOptionValue =
1009  [&](Attribute valueAttr) -> LogicalResult {
1010  if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
1011  int64_t dynamicOptionIdx = paramOperand.getIndex().getInt();
1012  if (dynamicOptionIdx < 0 ||
1013  dynamicOptionIdx >= static_cast<int64_t>(dynamicOptions.size()))
1014  return emitOpError()
1015  << "dynamic option index " << dynamicOptionIdx
1016  << " is out of bounds for the number of dynamic options: "
1017  << dynamicOptions.size();
  // A null entry means an earlier reference already claimed this operand.
1018  if (dynamicOptions[dynamicOptionIdx] == nullptr)
1019  return emitOpError() << "dynamic option index " << dynamicOptionIdx
1020  << " is already used in options";
1021  dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
1022  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1023  // Recurse into ArrayAttrs as they may contain references to operands.
1024  for (auto eltAttr : arrayAttr)
1025  if (failed(checkOptionValue(eltAttr)))
1026  return failure();
1027  }
1028  return success();
1029  };
1030 
1031  for (NamedAttribute namedAttr : getOptions())
1032  if (failed(checkOptionValue(namedAttr.getValue())))
1033  return failure();
1034 
1035  // All dynamicOptions-params seen in the dict will have been set to null.
1036  for (Value dynamicOption : dynamicOptions)
1037  if (dynamicOption)
1038  return emitOpError() << "a param operand does not have a corresponding "
1039  << "param_operand attr in the options dict";
1040 
1041  return success();
1042 }
1043 
1044 //===----------------------------------------------------------------------===//
1045 // CastOp
1046 //===----------------------------------------------------------------------===//
1047 
1049 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
1050  Operation *target, ApplyToEachResultList &results,
1051  transform::TransformState &state) {
  // The cast only changes the static handle type: each target payload op is
  // forwarded unchanged as the result.
1052  results.push_back(target);
1054 }
1055 
1056 void transform::CastOp::getEffects(
1057  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1058  onlyReadsPayload(effects);
1059  onlyReadsHandle(getInputMutable(), effects);
1060  producesHandle(getOperation()->getOpResults(), effects);
1061 }
1062 
1063 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1064  assert(inputs.size() == 1 && "expected one input");
1065  assert(outputs.size() == 1 && "expected one output");
1066  return llvm::all_of(
1067  std::initializer_list<Type>{inputs.front(), outputs.front()},
1068  llvm::IsaPred<transform::TransformHandleTypeInterface>);
1069 }
1070 
1071 //===----------------------------------------------------------------------===//
1072 // Matching helpers shared by CollectMatchingOp and ForeachMatchOp
1073 //===----------------------------------------------------------------------===//
1074 
1075 /// Applies matcher operations from the given `block` using
1076 /// `blockArgumentMapping` to initialize block arguments. Updates `state`
1077 /// accordingly. Every non-terminator operation of the block must implement
1078 /// MatchOpInterface, otherwise a definite failure is emitted at its location.
1079 /// Matchers are applied in order and the first one that does not succeed
1080 /// stops the process: its result (silenceable or definite failure) is
1081 /// returned unchanged, and it is up to the caller to report or discard it. If
1082 /// all matchers in the block succeed, populates `mappings` with the payload
/// entities associated with the block terminator operands. Note that
/// `mappings` is cleared before that.
1085  ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
1087  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
1088  assert(block.getParent() && "cannot match using a detached block");
  // Region scope for the matcher body; presumably the mappings created below
  // are dropped when `matchScope` is destroyed (TODO confirm).
1089  auto matchScope = state.make_region_scope(*block.getParent());
1090  if (failed(
1091  state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
1093 
  // Apply the matchers in order, stopping at the first that does not succeed.
1094  for (Operation &match : block.without_terminator()) {
1095  if (!isa<transform::MatchOpInterface>(match)) {
1096  return emitDefiniteFailure(match.getLoc())
1097  << "expected operations in the match part to "
1098  "implement MatchOpInterface";
1099  }
1101  state.applyTransform(cast<transform::TransformOpInterface>(match));
1102  if (diag.succeeded())
1103  continue;
1104 
1105  return diag;
1106  }
1107 
1108  // Remember the values mapped to the terminator operands so we can
1109  // forward them to the action.
1110  ValueRange yieldedValues = block.getTerminator()->getOperands();
1111  // Our contract with the caller is that the mappings will contain only the
1112  // newly mapped values, clear the rest.
1113  mappings.clear();
1114  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
1116 }
1117 
1118 /// Returns `true` if both types implement one of the interfaces provided as
1119 /// template parameters.
1120 template <typename... Tys>
1121 static bool implementSameInterface(Type t1, Type t2) {
1122  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
1123 }
1124 
1125 /// Returns `true` if both types implement the same transform dialect type
1126 /// interface, i.e. both are operation handles, both are parameters, or both
/// are value handles. Mixed kinds (e.g. an operation handle and a parameter)
/// yield `false`.
1128  return implementSameInterface<transform::TransformHandleTypeInterface,
1129  transform::TransformParamTypeInterface,
1130  transform::TransformValueHandleTypeInterface>(
1131  t1, t2);
1132 }
1133 
1134 //===----------------------------------------------------------------------===//
1135 // CollectMatchingOp
1136 //===----------------------------------------------------------------------===//
1137 
1139 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
1140  transform::TransformResults &results,
1141  transform::TransformState &state) {
1142  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
1143  getOperation(), getMatcher());
1144  if (matcher.isExternal()) {
1145  return emitDefiniteFailure()
1146  << "unresolved external symbol " << getMatcher();
1147  }
1148 
1149  SmallVector<SmallVector<MappedValue>, 2> rawResults;
1150  rawResults.resize(getOperation()->getNumResults());
1151  std::optional<DiagnosedSilenceableFailure> maybeFailure;
1152  for (Operation *root : state.getPayloadOps(getRoot())) {
1153  WalkResult walkResult = root->walk([&](Operation *op) {
1154  LDBG(1, DEBUG_TYPE_MATCHER)
1155  << "matching "
1156  << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1157  << " @" << op;
1158 
1159  // Try matching.
1161  SmallVector<transform::MappedValue> inputMapping({op});
1163  matcher.getFunctionBody().front(),
1164  ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
1165  mappings);
1166  if (diag.isDefiniteFailure())
1167  return WalkResult::interrupt();
1168  if (diag.isSilenceableFailure()) {
1169  LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName()
1170  << " failed: " << diag.getMessage();
1171  return WalkResult::advance();
1172  }
1173 
1174  // If succeeded, collect results.
1175  for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
1176  if (mapping.size() != 1) {
1177  maybeFailure.emplace(emitSilenceableError()
1178  << "result #" << i << ", associated with "
1179  << mapping.size()
1180  << " payload objects, expected 1");
1181  return WalkResult::interrupt();
1182  }
1183  rawResults[i].push_back(mapping[0]);
1184  }
1185  return WalkResult::advance();
1186  });
1187  if (walkResult.wasInterrupted())
1188  return std::move(*maybeFailure);
1189  assert(!maybeFailure && "failure set but the walk was not interrupted");
1190 
1191  for (auto &&[opResult, rawResult] :
1192  llvm::zip_equal(getOperation()->getResults(), rawResults)) {
1193  results.setMappedValues(opResult, rawResult);
1194  }
1195  }
1197 }
1198 
1199 void transform::CollectMatchingOp::getEffects(
1200  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1201  onlyReadsHandle(getRootMutable(), effects);
1202  producesHandle(getOperation()->getOpResults(), effects);
1203  onlyReadsPayload(effects);
1204 }
1205 
1206 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
1207  SymbolTableCollection &symbolTable) {
1208  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1209  symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
1210  if (!matcherSymbol ||
1211  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1212  return emitError() << "unresolved matcher symbol " << getMatcher();
1213 
1214  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
1215  if (argumentTypes.size() != 1 ||
1216  !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
1217  return emitError()
1218  << "expected the matcher to take one operation handle argument";
1219  }
1220  if (!matcherSymbol.getArgAttr(
1221  0, transform::TransformDialect::kArgReadOnlyAttrName)) {
1222  return emitError() << "expected the matcher argument to be marked readonly";
1223  }
1224 
1225  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
1226  if (resultTypes.size() != getOperation()->getNumResults()) {
1227  return emitError()
1228  << "expected the matcher to yield as many values as op has results ("
1229  << getOperation()->getNumResults() << "), got "
1230  << resultTypes.size();
1231  }
1232 
1233  for (auto &&[i, matcherType, resultType] :
1234  llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1235  if (implementSameTransformInterface(matcherType, resultType))
1236  continue;
1237 
1238  return emitError()
1239  << "mismatching type interfaces for matcher result and op result #"
1240  << i;
1241  }
1242 
1243  return success();
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // ForeachMatchOp
1248 //===----------------------------------------------------------------------===//
1249 
1250 // This is fine because nothing is actually consumed by this op.
1251 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1252 
1254 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1255  transform::TransformResults &results,
1256  transform::TransformState &state) {
  // For every payload op nested in each root (the root itself is visited only
  // when getRestrictRoot() is set), tries the matchers in order and runs the
  // action paired with the first matcher that succeeds. Silenceable failures
  // of action ops are collected into a single silenceable error, and the
  // remaining action ops still run. Definite failures stop the walk.
  // Resolve all matcher/action symbols once, before walking the payload.
1258  matchActionPairs;
1259  matchActionPairs.reserve(getMatchers().size());
1260  SymbolTableCollection symbolTable;
1261  for (auto &&[matcher, action] :
1262  llvm::zip_equal(getMatchers(), getActions())) {
1263  auto matcherSymbol =
1264  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1265  getOperation(), cast<SymbolRefAttr>(matcher));
1266  auto actionSymbol =
1267  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1268  getOperation(), cast<SymbolRefAttr>(action));
1269  assert(matcherSymbol && actionSymbol &&
1270  "unresolved symbols not caught by the verifier");
1271 
1272  if (matcherSymbol.isExternal())
1273  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1274  if (actionSymbol.isExternal())
1275  return emitDefiniteFailure() << "unresolved external symbol " << action;
1276 
1277  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1278  }
1279 
1280  DiagnosedSilenceableFailure overallDiag =
1282 
1283  SmallVector<SmallVector<MappedValue>> matchInputMapping;
1284  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1285  SmallVector<SmallVector<MappedValue>> actionResultMapping;
1286  // Explicitly add the mapping for the first block argument (the op being
1287  // matched).
1288  matchInputMapping.emplace_back();
1289  transform::detail::prepareValueMappings(matchInputMapping,
1290  getForwardedInputs(), state);
  // `matchInputMapping` is not resized after this point, so this reference
  // stays valid. It is refilled with the current op before each match attempt.
1291  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1292  actionResultMapping.resize(getForwardedOutputs().size());
1293 
1294  for (Operation *root : state.getPayloadOps(getRoot())) {
1295  WalkResult walkResult = root->walk([&](Operation *op) {
1296  // If getRestrictRoot is not present, skip over the root op itself so we
1297  // don't invalidate it.
1298  if (!getRestrictRoot() && op == root)
1299  return WalkResult::advance();
1300 
1301  LDBG(1, DEBUG_TYPE_MATCHER)
1302  << "matching "
1303  << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions())
1304  << " @" << op;
1305 
1306  firstMatchArgument.clear();
1307  firstMatchArgument.push_back(op);
1308 
1309  // Try all the match/action pairs until the first successful match.
1310  for (auto [matcher, action] : matchActionPairs) {
1312  matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1313  state, matchOutputMapping);
1314  if (diag.isDefiniteFailure())
1315  return WalkResult::interrupt();
1316  if (diag.isSilenceableFailure()) {
1317  LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName()
1318  << " failed: " << diag.getMessage();
1319  continue;
1320  }
1321 
1322  auto scope = state.make_region_scope(action.getFunctionBody());
1323  if (failed(state.mapBlockArguments(
1324  action.getFunctionBody().front().getArguments(),
1325  matchOutputMapping))) {
1326  return WalkResult::interrupt();
1327  }
1328 
1329  for (Operation &transform :
1330  action.getFunctionBody().front().without_terminator()) {
1332  state.applyTransform(cast<TransformOpInterface>(transform));
1333  if (result.isDefiniteFailure())
1334  return WalkResult::interrupt();
1335  if (result.isSilenceableFailure()) {
1336  if (overallDiag.succeeded()) {
1337  overallDiag = emitSilenceableError() << "actions failed";
1338  }
1339  overallDiag.attachNote(action->getLoc())
1340  << "failed action: " << result.getMessage();
1341  overallDiag.attachNote(op->getLoc())
1342  << "when applied to this matching payload";
  // The failure is now recorded in `overallDiag`; silence the original.
1343  (void)result.silence();
1344  continue;
1345  }
1346  }
1348  MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1349  action.getFunctionBody().front().getTerminator()->getOperands(),
1350  state, getFlattenResults()))) {
1352  << "action @" << action.getName()
1353  << " has results associated with multiple payload entities, "
1354  "but flattening was not requested";
1355  return WalkResult::interrupt();
1356  }
1357  break;
1358  }
1359  return WalkResult::advance();
1360  });
1361  if (walkResult.wasInterrupted())
1363  }
1364 
1365  // The root operation should not have been affected, so we can just reassign
1366  // the payload to the result. Note that we need to consume the root handle to
1367  // make sure any handles to operations inside, that could have been affected
1368  // by actions, are invalidated.
1369  results.set(llvm::cast<OpResult>(getUpdated()),
1370  state.getPayloadOps(getRoot()));
1371  for (auto &&[result, mapping] :
1372  llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1373  results.setMappedValues(result, mapping);
1374  }
1375  return overallDiag;
1376 }
1377 
1378 void transform::ForeachMatchOp::getAsmResultNames(
1379  OpAsmSetValueNameFn setNameFn) {
1380  setNameFn(getUpdated(), "updated_root");
1381  for (Value v : getForwardedOutputs()) {
1382  setNameFn(v, "yielded");
1383  }
1384 }
1385 
1386 void transform::ForeachMatchOp::getEffects(
1387  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1388  // Bail if invalid.
1389  if (getOperation()->getNumOperands() < 1 ||
1390  getOperation()->getNumResults() < 1) {
1391  return modifiesPayload(effects);
1392  }
1393 
1394  consumesHandle(getRootMutable(), effects);
1395  onlyReadsHandle(getForwardedInputsMutable(), effects);
1396  producesHandle(getOperation()->getOpResults(), effects);
1397  modifiesPayload(effects);
1398 }
1399 
1400 /// Parses the comma-separated list of symbol reference pairs of the format
1401 /// `@matcher -> @action`.
1402 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1403  ArrayAttr &matchers,
1404  ArrayAttr &actions) {
1405  StringAttr matcher;
1406  StringAttr action;
1407  SmallVector<Attribute> matcherList;
1408  SmallVector<Attribute> actionList;
1409  do {
1410  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1411  parser.parseSymbolName(action)) {
1412  return failure();
1413  }
1414  matcherList.push_back(SymbolRefAttr::get(matcher));
1415  actionList.push_back(SymbolRefAttr::get(action));
1416  } while (parser.parseOptionalComma().succeeded());
1417 
1418  matchers = parser.getBuilder().getArrayAttr(matcherList);
1419  actions = parser.getBuilder().getArrayAttr(actionList);
1420  return success();
1421 }
1422 
1423 /// Prints the comma-separated list of symbol reference pairs of the format
1424 /// `@matcher -> @action`.
/// Each pair starts on a new line at an increased indentation. The separating
/// comma is printed right after every pair except the last one.
1426  ArrayAttr matchers, ArrayAttr actions) {
1427  printer.increaseIndent();
1428  printer.increaseIndent();
1429  for (auto &&[matcher, action, idx] : llvm::zip_equal(
1430  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1431  printer.printNewline();
1432  printer << cast<SymbolRefAttr>(matcher) << " -> "
1433  << cast<SymbolRefAttr>(action);
1434  if (idx != matchers.size() - 1)
1435  printer << ", ";
1436  }
1437  printer.decreaseIndent();
1438  printer.decreaseIndent();
1439 }
1440 
1441 LogicalResult transform::ForeachMatchOp::verify() {
1442  if (getMatchers().size() != getActions().size())
1443  return emitOpError() << "expected the same number of matchers and actions";
1444  if (getMatchers().empty())
1445  return emitOpError() << "expected at least one match/action pair";
1446 
1447  llvm::SmallPtrSet<Attribute, 8> matcherNames;
1448  for (Attribute name : getMatchers()) {
1449  if (matcherNames.insert(name).second)
1450  continue;
1451  emitWarning() << "matcher " << name
1452  << " is used more than once, only the first match will apply";
1453  }
1454 
1455  return success();
1456 }
1457 
1458 /// Checks that the arguments of the function-like operation `op` carry
1459 /// consistent consumption effect annotations (the consumed / readonly argument
1460 /// attributes). If `alsoVerifyInternal`, checks for annotations being present
/// even if they can be inferred from the body; external functions must always
/// annotate every argument. An argument may not be both consumed and readonly,
/// and an argument consumed in the body must not be marked readonly. If
/// `emitWarnings`, also warns about arguments marked consumed that the body
/// does not consume. Problems are reported as silenceable failures emitted on
/// `op`, which is expected to implement TransformOpInterface.
1462 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1463  bool alsoVerifyInternal = false) {
1464  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1465  llvm::SmallDenseSet<unsigned> consumedArguments;
  // Only functions with a body can be inspected for actual consumption.
1466  if (!op.isExternal()) {
1467  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1468  consumedArguments);
1469  }
1470  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1471  bool isConsumed =
1472  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1473  nullptr;
1474  bool isReadOnly =
1475  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1476  nullptr;
1477  if (isConsumed && isReadOnly) {
1478  return transformOp.emitSilenceableError()
1479  << "argument #" << i << " cannot be both readonly and consumed";
1480  }
1481  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1482  return transformOp.emitSilenceableError()
1483  << "must provide consumed/readonly status for arguments of "
1484  "external or called ops";
1485  }
  // External functions have no body to compare the annotations against.
1486  if (op.isExternal())
1487  continue;
1488 
  // Both-annotated arguments were rejected above, so this fires exactly for
  // readonly arguments that the body consumes. Unannotated arguments of
  // internal functions are accepted here.
1489  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1490  return transformOp.emitSilenceableError()
1491  << "argument #" << i
1492  << " is consumed in the body but is not marked as such";
1493  }
1494  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1495  // Cannot use op.emitWarning() here as it would attempt to verify the op
1496  // before printing, resulting in infinite recursion.
1497  emitWarning(op->getLoc())
1498  << "op argument #" << i
1499  << " is not consumed in the body but is marked as consumed";
1500  }
1501  }
1503 }
1504 
1505 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1506  SymbolTableCollection &symbolTable) {
1507  assert(getMatchers().size() == getActions().size());
1508  auto consumedAttr =
1509  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1510  for (auto &&[matcher, action] :
1511  llvm::zip_equal(getMatchers(), getActions())) {
1512  // Presence and typing.
1513  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1514  symbolTable.lookupNearestSymbolFrom(getOperation(),
1515  cast<SymbolRefAttr>(matcher)));
1516  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1517  symbolTable.lookupNearestSymbolFrom(getOperation(),
1518  cast<SymbolRefAttr>(action)));
1519  if (!matcherSymbol ||
1520  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1521  return emitError() << "unresolved matcher symbol " << matcher;
1522  if (!actionSymbol ||
1523  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1524  return emitError() << "unresolved action symbol " << action;
1525 
1526  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1527  /*emitWarnings=*/false,
1528  /*alsoVerifyInternal=*/true)
1529  .checkAndReport())) {
1530  return failure();
1531  }
1533  /*emitWarnings=*/false,
1534  /*alsoVerifyInternal=*/true)
1535  .checkAndReport())) {
1536  return failure();
1537  }
1538 
1539  // Input -> matcher forwarding.
1540  TypeRange operandTypes = getOperandTypes();
1541  TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1542  if (operandTypes.size() != matcherArguments.size()) {
1544  emitError() << "the number of operands (" << operandTypes.size()
1545  << ") doesn't match the number of matcher arguments ("
1546  << matcherArguments.size() << ") for " << matcher;
1547  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1548  return diag;
1549  }
1550  for (auto &&[i, operand, argument] :
1551  llvm::enumerate(operandTypes, matcherArguments)) {
1552  if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1554  emitOpError()
1555  << "does not expect matcher symbol to consume its operand #" << i;
1556  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1557  return diag;
1558  }
1559 
1560  if (implementSameTransformInterface(operand, argument))
1561  continue;
1562 
1564  emitError()
1565  << "mismatching type interfaces for operand and matcher argument #"
1566  << i << " of matcher " << matcher;
1567  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1568  return diag;
1569  }
1570 
1571  // Matcher -> action forwarding.
1572  TypeRange matcherResults = matcherSymbol.getResultTypes();
1573  TypeRange actionArguments = actionSymbol.getArgumentTypes();
1574  if (matcherResults.size() != actionArguments.size()) {
1575  return emitError() << "mismatching number of matcher results and "
1576  "action arguments between "
1577  << matcher << " (" << matcherResults.size() << ") and "
1578  << action << " (" << actionArguments.size() << ")";
1579  }
1580  for (auto &&[i, matcherType, actionType] :
1581  llvm::enumerate(matcherResults, actionArguments)) {
1582  if (implementSameTransformInterface(matcherType, actionType))
1583  continue;
1584 
1585  return emitError() << "mismatching type interfaces for matcher result "
1586  "and action argument #"
1587  << i << "of matcher " << matcher << " and action "
1588  << action;
1589  }
1590 
1591  // Action -> result forwarding.
1592  TypeRange actionResults = actionSymbol.getResultTypes();
1593  auto resultTypes = TypeRange(getResultTypes()).drop_front();
1594  if (actionResults.size() != resultTypes.size()) {
1596  emitError() << "the number of action results ("
1597  << actionResults.size() << ") for " << action
1598  << " doesn't match the number of extra op results ("
1599  << resultTypes.size() << ")";
1600  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1601  return diag;
1602  }
1603  for (auto &&[i, resultType, actionType] :
1604  llvm::enumerate(resultTypes, actionResults)) {
1605  if (implementSameTransformInterface(resultType, actionType))
1606  continue;
1607 
1609  emitError() << "mismatching type interfaces for action result #" << i
1610  << " of action " << action << " and op result";
1611  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1612  return diag;
1613  }
1614  }
1615  return success();
1616 }
1617 
1618 //===----------------------------------------------------------------------===//
1619 // ForeachOp
1620 //===----------------------------------------------------------------------===//
1621 
1623 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1624  transform::TransformResults &results,
1625  transform::TransformState &state) {
1626  // We store the payloads before executing the body as ops may be removed from
1627  // the mapping by the TrackingRewriter while iteration is in progress.
1629  detail::prepareValueMappings(payloads, getTargets(), state);
1630  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1631  bool withZipShortest = getWithZipShortest();
1632 
1633  // In case of `zip_shortest`, set the number of iterations to the
1634  // smallest payload in the targets.
1635  if (withZipShortest) {
1636  numIterations =
1637  llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1638  const SmallVector<MappedValue> &B) {
1639  return A.size() < B.size();
1640  })->size();
1641 
1642  for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1643  payloads[argIdx].resize(numIterations);
1644  }
1645 
1646  // As we will be "zipping" over them, check all payloads have the same size.
1647  // `zip_shortest` adjusts all payloads to the same size, so skip this check
1648  // when true.
1649  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1650  argIdx++) {
1651  if (payloads[argIdx].size() != numIterations) {
1652  return emitSilenceableError()
1653  << "prior targets' payload size (" << numIterations
1654  << ") differs from payload size (" << payloads[argIdx].size()
1655  << ") of target " << getTargets()[argIdx];
1656  }
1657  }
1658 
1659  // Start iterating, indexing into payloads to obtain the right arguments to
1660  // call the body with - each slice of payloads at the same argument index
1661  // corresponding to a tuple to use as the body's block arguments.
1662  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1663  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1664  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1665  auto scope = state.make_region_scope(getBody());
1666  // Set up arguments to the region's block.
1667  for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1668  MappedValue argument = payloads[argIdx][iterIdx];
1669  // Note that each blockArg's handle gets associated with just a single
1670  // element from the corresponding target's payload.
1671  if (failed(state.mapBlockArgument(blockArg, {argument})))
1673  }
1674 
1675  // Execute loop body.
1676  for (Operation &transform : getBody().front().without_terminator()) {
1677  DiagnosedSilenceableFailure result = state.applyTransform(
1678  llvm::cast<transform::TransformOpInterface>(transform));
1679  if (!result.succeeded())
1680  return result;
1681  }
1682 
1683  // Append yielded payloads to corresponding results from prior iterations.
1684  OperandRange yieldOperands = getYieldOp().getOperands();
1685  for (auto &&[result, yieldOperand, resTuple] :
1686  llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1687  // NB: each iteration we add any number of ops/vals/params to a result.
1688  if (isa<TransformHandleTypeInterface>(result.getType()))
1689  llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1690  else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1691  llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1692  else if (isa<TransformParamTypeInterface>(result.getType()))
1693  llvm::append_range(resTuple, state.getParams(yieldOperand));
1694  else
1695  assert(false && "unhandled handle type");
1696  }
1697 
1698  // Associate the accumulated result payloads to the op's actual results.
1699  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1700  results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1701 
1703 }
1704 
1705 void transform::ForeachOp::getEffects(
1706  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1707  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1708  // arity errors, this method might get called before/in absence of `verify()`.
1709  for (auto &&[target, blockArg] :
1710  llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1711  BlockArgument blockArgument = blockArg;
1712  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1713  return isHandleConsumed(blockArgument,
1714  cast<TransformOpInterface>(&op));
1715  })) {
1716  consumesHandle(target, effects);
1717  } else {
1718  onlyReadsHandle(target, effects);
1719  }
1720  }
1721 
1722  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1723  return doesModifyPayload(cast<TransformOpInterface>(&op));
1724  })) {
1725  modifiesPayload(effects);
1726  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1727  return doesReadPayload(cast<TransformOpInterface>(&op));
1728  })) {
1729  onlyReadsPayload(effects);
1730  }
1731 
1732  producesHandle(getOperation()->getOpResults(), effects);
1733 }
1734 
1735 void transform::ForeachOp::getSuccessorRegions(
1736  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1737  Region *bodyRegion = &getBody();
1738  if (point.isParent()) {
1739  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1740  return;
1741  }
1742 
1743  // Branch back to the region or the parent.
1744  assert(point == getBody() && "unexpected region index");
1745  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1746  regions.emplace_back();
1747 }
1748 
1750 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1751  // Each block argument handle is mapped to a subset (one op to be precise)
1752  // of the payload of the corresponding `targets` operand of ForeachOp.
1753  assert(point == getBody() && "unexpected region index");
1754  return getOperation()->getOperands();
1755 }
1756 
1757 transform::YieldOp transform::ForeachOp::getYieldOp() {
1758  return cast<transform::YieldOp>(getBody().front().getTerminator());
1759 }
1760 
1761 LogicalResult transform::ForeachOp::verify() {
1762  for (auto [targetOpt, bodyArgOpt] :
1763  llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1764  if (!targetOpt || !bodyArgOpt)
1765  return emitOpError() << "expects the same number of targets as the body "
1766  "has block arguments";
1767  if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1768  return emitOpError(
1769  "expects co-indexed targets and the body's "
1770  "block arguments to have the same op/value/param type");
1771  }
1772 
1773  for (auto [resultOpt, yieldOperandOpt] :
1774  llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1775  if (!resultOpt || !yieldOperandOpt)
1776  return emitOpError() << "expects the same number of results as the "
1777  "yield terminator has operands";
1778  if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1779  return emitOpError("expects co-indexed results and yield "
1780  "operands to have the same op/value/param type");
1781  }
1782 
1783  return success();
1784 }
1785 
1786 //===----------------------------------------------------------------------===//
1787 // GetParentOp
1788 //===----------------------------------------------------------------------===//
1789 
1791 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1792  transform::TransformResults &results,
1793  transform::TransformState &state) {
1794  SmallVector<Operation *> parents;
1795  DenseSet<Operation *> resultSet;
1796  for (Operation *target : state.getPayloadOps(getTarget())) {
1797  Operation *parent = target;
1798  for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1799  parent = parent->getParentOp();
1800  while (parent) {
1801  bool checkIsolatedFromAbove =
1802  !getIsolatedFromAbove() ||
1804  bool checkOpName = !getOpName().has_value() ||
1805  parent->getName().getStringRef() == *getOpName();
1806  if (checkIsolatedFromAbove && checkOpName)
1807  break;
1808  parent = parent->getParentOp();
1809  }
1810  if (!parent) {
1811  if (getAllowEmptyResults()) {
1812  results.set(llvm::cast<OpResult>(getResult()), parents);
1814  }
1816  emitSilenceableError()
1817  << "could not find a parent op that matches all requirements";
1818  diag.attachNote(target->getLoc()) << "target op";
1819  return diag;
1820  }
1821  }
1822  if (getDeduplicate()) {
1823  if (resultSet.insert(parent).second)
1824  parents.push_back(parent);
1825  } else {
1826  parents.push_back(parent);
1827  }
1828  }
1829  results.set(llvm::cast<OpResult>(getResult()), parents);
1831 }
1832 
1833 //===----------------------------------------------------------------------===//
1834 // GetConsumersOfResult
1835 //===----------------------------------------------------------------------===//
1836 
1838 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1839  transform::TransformResults &results,
1840  transform::TransformState &state) {
1841  int64_t resultNumber = getResultNumber();
1842  auto payloadOps = state.getPayloadOps(getTarget());
1843  if (std::empty(payloadOps)) {
1844  results.set(cast<OpResult>(getResult()), {});
1846  }
1847  if (!llvm::hasSingleElement(payloadOps))
1848  return emitDefiniteFailure()
1849  << "handle must be mapped to exactly one payload op";
1850 
1851  Operation *target = *payloadOps.begin();
1852  if (target->getNumResults() <= resultNumber)
1853  return emitDefiniteFailure() << "result number overflow";
1854  results.set(llvm::cast<OpResult>(getResult()),
1855  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1857 }
1858 
1859 //===----------------------------------------------------------------------===//
1860 // GetDefiningOp
1861 //===----------------------------------------------------------------------===//
1862 
1864 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1865  transform::TransformResults &results,
1866  transform::TransformState &state) {
1867  SmallVector<Operation *> definingOps;
1868  for (Value v : state.getPayloadValues(getTarget())) {
1869  if (llvm::isa<BlockArgument>(v)) {
1871  emitSilenceableError() << "cannot get defining op of block argument";
1872  diag.attachNote(v.getLoc()) << "target value";
1873  return diag;
1874  }
1875  definingOps.push_back(v.getDefiningOp());
1876  }
1877  results.set(llvm::cast<OpResult>(getResult()), definingOps);
1879 }
1880 
1881 //===----------------------------------------------------------------------===//
1882 // GetProducerOfOperand
1883 //===----------------------------------------------------------------------===//
1884 
1886 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1887  transform::TransformResults &results,
1888  transform::TransformState &state) {
1889  int64_t operandNumber = getOperandNumber();
1890  SmallVector<Operation *> producers;
1891  for (Operation *target : state.getPayloadOps(getTarget())) {
1892  Operation *producer =
1893  target->getNumOperands() <= operandNumber
1894  ? nullptr
1895  : target->getOperand(operandNumber).getDefiningOp();
1896  if (!producer) {
1898  emitSilenceableError()
1899  << "could not find a producer for operand number: " << operandNumber
1900  << " of " << *target;
1901  diag.attachNote(target->getLoc()) << "target op";
1902  return diag;
1903  }
1904  producers.push_back(producer);
1905  }
1906  results.set(llvm::cast<OpResult>(getResult()), producers);
1908 }
1909 
1910 //===----------------------------------------------------------------------===//
1911 // GetOperandOp
1912 //===----------------------------------------------------------------------===//
1913 
1915 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1916  transform::TransformResults &results,
1917  transform::TransformState &state) {
1918  SmallVector<Value> operands;
1919  for (Operation *target : state.getPayloadOps(getTarget())) {
1920  SmallVector<int64_t> operandPositions;
1922  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1923  target->getNumOperands(), operandPositions);
1924  if (diag.isSilenceableFailure()) {
1925  diag.attachNote(target->getLoc())
1926  << "while considering positions of this payload operation";
1927  return diag;
1928  }
1929  llvm::append_range(operands,
1930  llvm::map_range(operandPositions, [&](int64_t pos) {
1931  return target->getOperand(pos);
1932  }));
1933  }
1934  results.setValues(cast<OpResult>(getResult()), operands);
1936 }
1937 
1938 LogicalResult transform::GetOperandOp::verify() {
1939  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1940  getIsInverted(), getIsAll());
1941 }
1942 
1943 //===----------------------------------------------------------------------===//
1944 // GetResultOp
1945 //===----------------------------------------------------------------------===//
1946 
1948 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1949  transform::TransformResults &results,
1950  transform::TransformState &state) {
1951  SmallVector<Value> opResults;
1952  for (Operation *target : state.getPayloadOps(getTarget())) {
1953  SmallVector<int64_t> resultPositions;
1955  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1956  target->getNumResults(), resultPositions);
1957  if (diag.isSilenceableFailure()) {
1958  diag.attachNote(target->getLoc())
1959  << "while considering positions of this payload operation";
1960  return diag;
1961  }
1962  llvm::append_range(opResults,
1963  llvm::map_range(resultPositions, [&](int64_t pos) {
1964  return target->getResult(pos);
1965  }));
1966  }
1967  results.setValues(cast<OpResult>(getResult()), opResults);
1969 }
1970 
1971 LogicalResult transform::GetResultOp::verify() {
1972  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1973  getIsInverted(), getIsAll());
1974 }
1975 
1976 //===----------------------------------------------------------------------===//
1977 // GetTypeOp
1978 //===----------------------------------------------------------------------===//
1979 
1980 void transform::GetTypeOp::getEffects(
1981  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1982  onlyReadsHandle(getValueMutable(), effects);
1983  producesHandle(getOperation()->getOpResults(), effects);
1984  onlyReadsPayload(effects);
1985 }
1986 
1988 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1989  transform::TransformResults &results,
1990  transform::TransformState &state) {
1991  SmallVector<Attribute> params;
1992  for (Value value : state.getPayloadValues(getValue())) {
1993  Type type = value.getType();
1994  if (getElemental()) {
1995  if (auto shaped = dyn_cast<ShapedType>(type)) {
1996  type = shaped.getElementType();
1997  }
1998  }
1999  params.push_back(TypeAttr::get(type));
2000  }
2001  results.setParams(cast<OpResult>(getResult()), params);
2003 }
2004 
2005 //===----------------------------------------------------------------------===//
2006 // IncludeOp
2007 //===----------------------------------------------------------------------===//
2008 
2009 /// Applies the transform ops contained in `block`. Maps `results` to the same
2010 /// values as the operands of the block terminator.
2012 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
2014  transform::TransformResults &results) {
2015  // Apply the sequenced ops one by one.
2016  for (Operation &transform : block.without_terminator()) {
2018  state.applyTransform(cast<transform::TransformOpInterface>(transform));
2019  if (result.isDefiniteFailure())
2020  return result;
2021 
2022  if (result.isSilenceableFailure()) {
2023  if (mode == transform::FailurePropagationMode::Propagate) {
2024  // Propagate empty results in case of early exit.
2025  forwardEmptyOperands(&block, state, results);
2026  return result;
2027  }
2028  (void)result.silence();
2029  }
2030  }
2031 
2032  // Forward the operation mapping for values yielded from the sequence to the
2033  // values produced by the sequence op.
2034  transform::detail::forwardTerminatorOperands(&block, state, results);
2036 }
2037 
2039 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
2040  transform::TransformResults &results,
2041  transform::TransformState &state) {
2042  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2043  getOperation(), getTarget());
2044  assert(callee && "unverified reference to unknown symbol");
2045 
2046  if (callee.isExternal())
2047  return emitDefiniteFailure() << "unresolved external named sequence";
2048 
2049  // Map operands to block arguments.
2051  detail::prepareValueMappings(mappings, getOperands(), state);
2052  auto scope = state.make_region_scope(callee.getBody());
2053  for (auto &&[arg, map] :
2054  llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
2055  if (failed(state.mapBlockArgument(arg, map)))
2057  }
2058 
2060  callee.getBody().front(), getFailurePropagationMode(), state, results);
2061  mappings.clear();
2063  mappings, callee.getBody().front().getTerminator()->getOperands(), state);
2064  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
2065  results.setMappedValues(result, mapping);
2066  return result;
2067 }
2068 
2070 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
2071 
2072 void transform::IncludeOp::getEffects(
2073  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2074  // Always mark as modifying the payload.
2075  // TODO: a mechanism to annotate effects on payload. Even when all handles are
2076  // only read, the payload may still be modified, so we currently stay on the
2077  // conservative side and always indicate modification. This may prevent some
2078  // code reordering.
2079  modifiesPayload(effects);
2080 
2081  // Results are always produced.
2082  producesHandle(getOperation()->getOpResults(), effects);
2083 
2084  // Adds default effects to operands and results. This will be added if
2085  // preconditions fail so the trait verifier doesn't complain about missing
2086  // effects and the real precondition failure is reported later on.
2087  auto defaultEffects = [&] {
2088  onlyReadsHandle(getOperation()->getOpOperands(), effects);
2089  };
2090 
2091  // Bail if the callee is unknown. This may run as part of the verification
2092  // process before we verified the validity of the callee or of this op.
2093  auto target =
2094  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
2095  if (!target)
2096  return defaultEffects();
2097  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
2098  getOperation(), getTarget());
2099  if (!callee)
2100  return defaultEffects();
2101  DiagnosedSilenceableFailure earlyVerifierResult =
2102  verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
2103  if (!earlyVerifierResult.succeeded()) {
2104  (void)earlyVerifierResult.silence();
2105  return defaultEffects();
2106  }
2107 
2108  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
2109  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
2110  consumesHandle(getOperation()->getOpOperand(i), effects);
2111  else
2112  onlyReadsHandle(getOperation()->getOpOperand(i), effects);
2113  }
2114 }
2115 
2116 LogicalResult
2117 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2118  // Access through indirection and do additional checking because this may be
2119  // running before the main op verifier.
2120  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
2121  if (!targetAttr)
2122  return emitOpError() << "expects a 'target' symbol reference attribute";
2123 
2124  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
2125  *this, targetAttr);
2126  if (!target)
2127  return emitOpError() << "does not reference a named transform sequence";
2128 
2129  FunctionType fnType = target.getFunctionType();
2130  if (fnType.getNumInputs() != getNumOperands())
2131  return emitError("incorrect number of operands for callee");
2132 
2133  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
2134  if (getOperand(i).getType() != fnType.getInput(i)) {
2135  return emitOpError("operand type mismatch: expected operand type ")
2136  << fnType.getInput(i) << ", but provided "
2137  << getOperand(i).getType() << " for operand number " << i;
2138  }
2139  }
2140 
2141  if (fnType.getNumResults() != getNumResults())
2142  return emitError("incorrect number of results for callee");
2143 
2144  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
2145  Type resultType = getResult(i).getType();
2146  Type funcType = fnType.getResult(i);
2147  if (!implementSameTransformInterface(resultType, funcType)) {
2148  return emitOpError() << "type of result #" << i
2149  << " must implement the same transform dialect "
2150  "interface as the corresponding callee result";
2151  }
2152  }
2153 
2155  cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
2156  /*alsoVerifyInternal=*/true)
2157  .checkAndReport();
2158 }
2159 
2160 //===----------------------------------------------------------------------===//
2161 // MatchOperationEmptyOp
2162 //===----------------------------------------------------------------------===//
2163 
2164 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
2165  ::std::optional<::mlir::Operation *> maybeCurrent,
2167  if (!maybeCurrent.has_value()) {
2168  LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp success";
2170  }
2171  LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp failure";
2172  return emitSilenceableError() << "operation is not empty";
2173 }
2174 
2175 //===----------------------------------------------------------------------===//
2176 // MatchOperationNameOp
2177 //===----------------------------------------------------------------------===//
2178 
2179 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
2180  Operation *current, transform::TransformResults &results,
2181  transform::TransformState &state) {
2182  StringRef currentOpName = current->getName().getStringRef();
2183  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
2184  if (acceptedAttr.getValue() == currentOpName)
2186  }
2187  return emitSilenceableError() << "wrong operation name";
2188 }
2189 
2190 //===----------------------------------------------------------------------===//
2191 // MatchParamCmpIOp
2192 //===----------------------------------------------------------------------===//
2193 
2195 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
2196  transform::TransformResults &results,
2197  transform::TransformState &state) {
2198  auto signedAPIntAsString = [&](const APInt &value) {
2199  std::string str;
2200  llvm::raw_string_ostream os(str);
2201  value.print(os, /*isSigned=*/true);
2202  return str;
2203  };
2204 
2205  ArrayRef<Attribute> params = state.getParams(getParam());
2206  ArrayRef<Attribute> references = state.getParams(getReference());
2207 
2208  if (params.size() != references.size()) {
2209  return emitSilenceableError()
2210  << "parameters have different payload lengths (" << params.size()
2211  << " vs " << references.size() << ")";
2212  }
2213 
2214  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
2215  auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
2216  auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
2217  if (!intAttr || !refAttr) {
2218  return emitDefiniteFailure()
2219  << "non-integer parameter value not expected";
2220  }
2221  if (intAttr.getType() != refAttr.getType()) {
2222  return emitDefiniteFailure()
2223  << "mismatching integer attribute types in parameter #" << i;
2224  }
2225  APInt value = intAttr.getValue();
2226  APInt refValue = refAttr.getValue();
2227 
2228  // TODO: this copy will not be necessary in C++20.
2229  int64_t position = i;
2230  auto reportError = [&](StringRef direction) {
2232  emitSilenceableError() << "expected parameter to be " << direction
2233  << " " << signedAPIntAsString(refValue)
2234  << ", got " << signedAPIntAsString(value);
2235  diag.attachNote(getParam().getLoc())
2236  << "value # " << position
2237  << " associated with the parameter defined here";
2238  return diag;
2239  };
2240 
2241  switch (getPredicate()) {
2242  case MatchCmpIPredicate::eq:
2243  if (value.eq(refValue))
2244  break;
2245  return reportError("equal to");
2246  case MatchCmpIPredicate::ne:
2247  if (value.ne(refValue))
2248  break;
2249  return reportError("not equal to");
2250  case MatchCmpIPredicate::lt:
2251  if (value.slt(refValue))
2252  break;
2253  return reportError("less than");
2254  case MatchCmpIPredicate::le:
2255  if (value.sle(refValue))
2256  break;
2257  return reportError("less than or equal to");
2258  case MatchCmpIPredicate::gt:
2259  if (value.sgt(refValue))
2260  break;
2261  return reportError("greater than");
2262  case MatchCmpIPredicate::ge:
2263  if (value.sge(refValue))
2264  break;
2265  return reportError("greater than or equal to");
2266  }
2267  }
2269 }
2270 
2271 void transform::MatchParamCmpIOp::getEffects(
2272  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2273  onlyReadsHandle(getParamMutable(), effects);
2274  onlyReadsHandle(getReferenceMutable(), effects);
2275 }
2276 
2277 //===----------------------------------------------------------------------===//
2278 // ParamConstantOp
2279 //===----------------------------------------------------------------------===//
2280 
2282 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2283  transform::TransformResults &results,
2284  transform::TransformState &state) {
2285  results.setParams(cast<OpResult>(getParam()), {getValue()});
2287 }
2288 
2289 //===----------------------------------------------------------------------===//
2290 // MergeHandlesOp
2291 //===----------------------------------------------------------------------===//
2292 
2294 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2295  transform::TransformResults &results,
2296  transform::TransformState &state) {
2297  ValueRange handles = getHandles();
2298  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2299  SmallVector<Operation *> operations;
2300  for (Value operand : handles)
2301  llvm::append_range(operations, state.getPayloadOps(operand));
2302  if (!getDeduplicate()) {
2303  results.set(llvm::cast<OpResult>(getResult()), operations);
2305  }
2306 
2307  SetVector<Operation *> uniqued(llvm::from_range, operations);
2308  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2310  }
2311 
2312  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2313  SmallVector<Attribute> attrs;
2314  for (Value attribute : handles)
2315  llvm::append_range(attrs, state.getParams(attribute));
2316  if (!getDeduplicate()) {
2317  results.setParams(cast<OpResult>(getResult()), attrs);
2319  }
2320 
2321  SetVector<Attribute> uniqued(llvm::from_range, attrs);
2322  results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2324  }
2325 
2326  assert(
2327  llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2328  "expected value handle type");
2329  SmallVector<Value> payloadValues;
2330  for (Value value : handles)
2331  llvm::append_range(payloadValues, state.getPayloadValues(value));
2332  if (!getDeduplicate()) {
2333  results.setValues(cast<OpResult>(getResult()), payloadValues);
2335  }
2336 
2337  SetVector<Value> uniqued(llvm::from_range, payloadValues);
2338  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2340 }
2341 
2342 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2343  // Handles may be the same if deduplicating is enabled.
2344  return getDeduplicate();
2345 }
2346 
2347 void transform::MergeHandlesOp::getEffects(
2348  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2349  onlyReadsHandle(getHandlesMutable(), effects);
2350  producesHandle(getOperation()->getOpResults(), effects);
2351 
2352  // There are no effects on the Payload IR as this is only a handle
2353  // manipulation.
2354 }
2355 
2356 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2357  if (getDeduplicate() || getHandles().size() != 1)
2358  return {};
2359 
2360  // If deduplication is not required and there is only one operand, it can be
2361  // used directly instead of merging.
2362  return getHandles().front();
2363 }
2364 
2365 //===----------------------------------------------------------------------===//
2366 // NamedSequenceOp
2367 //===----------------------------------------------------------------------===//
2368 
2370 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2371  transform::TransformResults &results,
2372  transform::TransformState &state) {
2373  if (isExternal())
2374  return emitDefiniteFailure() << "unresolved external named sequence";
2375 
2376  // Map the entry block argument to the list of operations.
2377  // Note: this is the same implementation as PossibleTopLevelTransformOp but
2378  // without attaching the interface / trait since that is tailored to a
2379  // dangling top-level op that does not get "called".
2380  auto scope = state.make_region_scope(getBody());
2382  state, this->getOperation(), getBody())))
2384 
2385  return applySequenceBlock(getBody().front(),
2386  FailurePropagationMode::Propagate, state, results);
2387 }
2388 
2389 void transform::NamedSequenceOp::getEffects(
2390  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2391 
2393  OperationState &result) {
2395  parser, result, /*allowVariadic=*/false,
2396  getFunctionTypeAttrName(result.name),
2397  [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2399  std::string &) { return builder.getFunctionType(inputs, results); },
2400  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2401 }
2402 
2405  printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2406  getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2407  getResAttrsAttrName());
2408 }
2409 
2410 /// Verifies that a symbol function-like transform dialect operation has the
2411 /// signature and the terminator that have conforming types, i.e., types
2412 /// implementing the same transform dialect type interface. If `allowExternal`
2413 /// is set, allow external symbols (declarations) and don't check the terminator
2414 /// as it may not exist.
2416 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2417  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2420  << "cannot be defined inside another transform op";
2421  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2422  return diag;
2423  }
2424 
2425  if (op.isExternal() || op.getFunctionBody().empty()) {
2426  if (allowExternal)
2428 
2429  return emitSilenceableFailure(op) << "cannot be external";
2430  }
2431 
2432  if (op.getFunctionBody().front().empty())
2433  return emitSilenceableFailure(op) << "expected a non-empty body block";
2434 
2435  Operation *terminator = &op.getFunctionBody().front().back();
2436  if (!isa<transform::YieldOp>(terminator)) {
2438  << "expected '"
2439  << transform::YieldOp::getOperationName()
2440  << "' as terminator";
2441  diag.attachNote(terminator->getLoc()) << "terminator";
2442  return diag;
2443  }
2444 
2445  if (terminator->getNumOperands() != op.getResultTypes().size()) {
2446  return emitSilenceableFailure(terminator)
2447  << "expected terminator to have as many operands as the parent op "
2448  "has results";
2449  }
2450  for (auto [i, operandType, resultType] : llvm::zip_equal(
2451  llvm::seq<unsigned>(0, terminator->getNumOperands()),
2452  terminator->getOperands().getType(), op.getResultTypes())) {
2453  if (operandType == resultType)
2454  continue;
2455  return emitSilenceableFailure(terminator)
2456  << "the type of the terminator operand #" << i
2457  << " must match the type of the corresponding parent op result ("
2458  << operandType << " vs " << resultType << ")";
2459  }
2460 
2462 }
2463 
2464 /// Verification of a NamedSequenceOp. This does not report the error
2465 /// immediately, so it can be used to check for op's well-formedness before the
2466 /// verifier runs, e.g., during trait verification.
2468 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2469  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2470  if (!parent->getAttr(
2471  transform::TransformDialect::kWithNamedSequenceAttrName)) {
2474  << "expects the parent symbol table to have the '"
2475  << transform::TransformDialect::kWithNamedSequenceAttrName
2476  << "' attribute";
2477  diag.attachNote(parent->getLoc()) << "symbol table operation";
2478  return diag;
2479  }
2480  }
2481 
2482  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2485  << "cannot be defined inside another transform op";
2486  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2487  return diag;
2488  }
2489 
2490  if (op.isExternal() || op.getBody().empty())
2491  return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2492  emitWarnings);
2493 
2494  if (op.getBody().front().empty())
2495  return emitSilenceableFailure(op) << "expected a non-empty body block";
2496 
2497  Operation *terminator = &op.getBody().front().back();
2498  if (!isa<transform::YieldOp>(terminator)) {
2500  << "expected '"
2501  << transform::YieldOp::getOperationName()
2502  << "' as terminator";
2503  diag.attachNote(terminator->getLoc()) << "terminator";
2504  return diag;
2505  }
2506 
2507  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2508  return emitSilenceableFailure(terminator)
2509  << "expected terminator to have as many operands as the parent op "
2510  "has results";
2511  }
2512  for (auto [i, operandType, resultType] :
2513  llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2514  terminator->getOperands().getType(),
2515  op.getFunctionType().getResults())) {
2516  if (operandType == resultType)
2517  continue;
2518  return emitSilenceableFailure(terminator)
2519  << "the type of the terminator operand #" << i
2520  << " must match the type of the corresponding parent op result ("
2521  << operandType << " vs " << resultType << ")";
2522  }
2523 
2524  auto funcOp = cast<FunctionOpInterface>(*op);
2526  verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2527  if (!diag.succeeded())
2528  return diag;
2529 
2530  return verifyYieldingSingleBlockOp(funcOp,
2531  /*allowExternal=*/true);
2532 }
2533 
2534 LogicalResult transform::NamedSequenceOp::verify() {
2535  // Actual verification happens in a separate function for reusability.
2536  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2537 }
2538 
2539 template <typename FnTy>
2540 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2541  Type bbArgType, TypeRange extraBindingTypes,
2542  FnTy bodyBuilder) {
2543  SmallVector<Type> types;
2544  types.reserve(1 + extraBindingTypes.size());
2545  types.push_back(bbArgType);
2546  llvm::append_range(types, extraBindingTypes);
2547 
2548  OpBuilder::InsertionGuard guard(builder);
2549  Region *region = state.regions.back().get();
2550  Block *bodyBlock =
2551  builder.createBlock(region, region->begin(), types,
2552  SmallVector<Location>(types.size(), state.location));
2553 
2554  // Populate body.
2555  builder.setInsertionPointToStart(bodyBlock);
2556  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2557  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2558  } else {
2559  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2560  bodyBlock->getArguments().drop_front());
2561  }
2562 }
2563 
2564 void transform::NamedSequenceOp::build(OpBuilder &builder,
2565  OperationState &state, StringRef symName,
2566  Type rootType, TypeRange resultTypes,
2567  SequenceBodyBuilderFn bodyBuilder,
2569  ArrayRef<DictionaryAttr> argAttrs) {
2570  state.addAttribute(SymbolTable::getSymbolAttrName(),
2571  builder.getStringAttr(symName));
2572  state.addAttribute(getFunctionTypeAttrName(state.name),
2574  rootType, resultTypes)));
2575  state.attributes.append(attrs.begin(), attrs.end());
2576  state.addRegion();
2577 
2578  buildSequenceBody(builder, state, rootType,
2579  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2580 }
2581 
2582 //===----------------------------------------------------------------------===//
2583 // NumAssociationsOp
2584 //===----------------------------------------------------------------------===//
2585 
2587 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2588  transform::TransformResults &results,
2589  transform::TransformState &state) {
2590  size_t numAssociations =
2592  .Case([&](TransformHandleTypeInterface opHandle) {
2593  return llvm::range_size(state.getPayloadOps(getHandle()));
2594  })
2595  .Case([&](TransformValueHandleTypeInterface valueHandle) {
2596  return llvm::range_size(state.getPayloadValues(getHandle()));
2597  })
2598  .Case([&](TransformParamTypeInterface param) {
2599  return llvm::range_size(state.getParams(getHandle()));
2600  })
2601  .Default([](Type) {
2602  llvm_unreachable("unknown kind of transform dialect type");
2603  return 0;
2604  });
2605  results.setParams(cast<OpResult>(getNum()),
2606  rewriter.getI64IntegerAttr(numAssociations));
2608 }
2609 
2610 LogicalResult transform::NumAssociationsOp::verify() {
2611  // Verify that the result type accepts an i64 attribute as payload.
2612  auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2613  return resultType
2614  .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2615  .checkAndReport();
2616 }
2617 
2618 //===----------------------------------------------------------------------===//
2619 // SelectOp
2620 //===----------------------------------------------------------------------===//
2621 
2623 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2624  transform::TransformResults &results,
2625  transform::TransformState &state) {
2626  SmallVector<Operation *> result;
2627  auto payloadOps = state.getPayloadOps(getTarget());
2628  for (Operation *op : payloadOps) {
2629  if (op->getName().getStringRef() == getOpName())
2630  result.push_back(op);
2631  }
2632  results.set(cast<OpResult>(getResult()), result);
2634 }
2635 
2636 //===----------------------------------------------------------------------===//
2637 // SplitHandleOp
2638 //===----------------------------------------------------------------------===//
2639 
2640 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2641  Value target, int64_t numResultHandles) {
2642  result.addOperands(target);
2643  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2644 }
2645 
2647 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2648  transform::TransformResults &results,
2649  transform::TransformState &state) {
2650  int64_t numPayloads =
2652  .Case<TransformHandleTypeInterface>([&](auto x) {
2653  return llvm::range_size(state.getPayloadOps(getHandle()));
2654  })
2655  .Case<TransformValueHandleTypeInterface>([&](auto x) {
2656  return llvm::range_size(state.getPayloadValues(getHandle()));
2657  })
2658  .Case<TransformParamTypeInterface>([&](auto x) {
2659  return llvm::range_size(state.getParams(getHandle()));
2660  })
2661  .Default([](auto x) {
2662  llvm_unreachable("unknown transform dialect type interface");
2663  return -1;
2664  });
2665 
2666  auto produceNumOpsError = [&]() {
2667  return emitSilenceableError()
2668  << getHandle() << " expected to contain " << this->getNumResults()
2669  << " payloads but it contains " << numPayloads << " payloads";
2670  };
2671 
2672  // Fail if there are more payload ops than results and no overflow result was
2673  // specified.
2674  if (numPayloads > getNumResults() && !getOverflowResult().has_value())
2675  return produceNumOpsError();
2676 
2677  // Fail if there are more results than payload ops. Unless:
2678  // - "fail_on_payload_too_small" is set to "false", or
2679  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2680  if (numPayloads < getNumResults() && getFailOnPayloadTooSmall() &&
2681  (numPayloads != 0 || !getPassThroughEmptyHandle()))
2682  return produceNumOpsError();
2683 
2684  // Distribute payloads.
2685  SmallVector<SmallVector<MappedValue, 1>> resultHandles(getNumResults(), {});
2686  if (getOverflowResult())
2687  resultHandles[*getOverflowResult()].reserve(numPayloads - getNumResults());
2688 
2689  auto container = [&]() {
2690  if (isa<TransformHandleTypeInterface>(getHandle().getType())) {
2691  return llvm::map_to_vector(
2692  state.getPayloadOps(getHandle()),
2693  [](Operation *op) -> MappedValue { return op; });
2694  }
2695  if (isa<TransformValueHandleTypeInterface>(getHandle().getType())) {
2696  return llvm::map_to_vector(state.getPayloadValues(getHandle()),
2697  [](Value v) -> MappedValue { return v; });
2698  }
2699  assert(isa<TransformParamTypeInterface>(getHandle().getType()) &&
2700  "unsupported kind of transform dialect type");
2701  return llvm::map_to_vector(state.getParams(getHandle()),
2702  [](Attribute a) -> MappedValue { return a; });
2703  }();
2704 
2705  for (auto &&en : llvm::enumerate(container)) {
2706  int64_t resultNum = en.index();
2707  if (resultNum >= getNumResults())
2708  resultNum = *getOverflowResult();
2709  resultHandles[resultNum].push_back(en.value());
2710  }
2711 
2712  // Set transform op results.
2713  for (auto &&it : llvm::enumerate(resultHandles))
2714  results.setMappedValues(llvm::cast<OpResult>(getResult(it.index())),
2715  it.value());
2716 
2718 }
2719 
2720 void transform::SplitHandleOp::getEffects(
2721  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2722  onlyReadsHandle(getHandleMutable(), effects);
2723  producesHandle(getOperation()->getOpResults(), effects);
2724  // There are no effects on the Payload IR as this is only a handle
2725  // manipulation.
2726 }
2727 
2728 LogicalResult transform::SplitHandleOp::verify() {
2729  if (getOverflowResult().has_value() &&
2730  !(*getOverflowResult() < getNumResults()))
2731  return emitOpError("overflow_result is not a valid result index");
2732 
2733  for (Type resultType : getResultTypes()) {
2734  if (implementSameTransformInterface(getHandle().getType(), resultType))
2735  continue;
2736 
2737  return emitOpError("expects result types to implement the same transform "
2738  "interface as the operand type");
2739  }
2740 
2741  return success();
2742 }
2743 
2744 //===----------------------------------------------------------------------===//
2745 // ReplicateOp
2746 //===----------------------------------------------------------------------===//
2747 
2749 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2750  transform::TransformResults &results,
2751  transform::TransformState &state) {
2752  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2753  for (const auto &en : llvm::enumerate(getHandles())) {
2754  Value handle = en.value();
2755  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2756  SmallVector<Operation *> current =
2757  llvm::to_vector(state.getPayloadOps(handle));
2758  SmallVector<Operation *> payload;
2759  payload.reserve(numRepetitions * current.size());
2760  for (unsigned i = 0; i < numRepetitions; ++i)
2761  llvm::append_range(payload, current);
2762  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2763  } else {
2764  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2765  "expected param type");
2766  ArrayRef<Attribute> current = state.getParams(handle);
2767  SmallVector<Attribute> params;
2768  params.reserve(numRepetitions * current.size());
2769  for (unsigned i = 0; i < numRepetitions; ++i)
2770  llvm::append_range(params, current);
2771  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2772  params);
2773  }
2774  }
2776 }
2777 
2778 void transform::ReplicateOp::getEffects(
2779  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2780  onlyReadsHandle(getPatternMutable(), effects);
2781  onlyReadsHandle(getHandlesMutable(), effects);
2782  producesHandle(getOperation()->getOpResults(), effects);
2783 }
2784 
2785 //===----------------------------------------------------------------------===//
2786 // SequenceOp
2787 //===----------------------------------------------------------------------===//
2788 
2790 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2791  transform::TransformResults &results,
2792  transform::TransformState &state) {
2793  // Map the entry block argument to the list of operations.
2794  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2795  if (failed(mapBlockArguments(state)))
2797 
2798  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2799  results);
2800 }
2801 
2802 static ParseResult parseSequenceOpOperands(
2803  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2804  Type &rootType,
2805  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2806  SmallVectorImpl<Type> &extraBindingTypes) {
2807  OpAsmParser::UnresolvedOperand rootOperand;
2808  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2809  if (!hasRoot.has_value()) {
2810  root = std::nullopt;
2811  return success();
2812  }
2813  if (failed(hasRoot.value()))
2814  return failure();
2815  root = rootOperand;
2816 
2817  if (succeeded(parser.parseOptionalComma())) {
2818  if (failed(parser.parseOperandList(extraBindings)))
2819  return failure();
2820  }
2821  if (failed(parser.parseColon()))
2822  return failure();
2823 
2824  // The paren is truly optional.
2825  (void)parser.parseOptionalLParen();
2826 
2827  if (failed(parser.parseType(rootType))) {
2828  return failure();
2829  }
2830 
2831  if (!extraBindings.empty()) {
2832  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2833  return failure();
2834  }
2835 
2836  if (extraBindingTypes.size() != extraBindings.size()) {
2837  return parser.emitError(parser.getNameLoc(),
2838  "expected types to be provided for all operands");
2839  }
2840 
2841  // The paren is truly optional.
2842  (void)parser.parseOptionalRParen();
2843  return success();
2844 }
2845 
2847  Value root, Type rootType,
2848  ValueRange extraBindings,
2849  TypeRange extraBindingTypes) {
2850  if (!root)
2851  return;
2852 
2853  printer << root;
2854  bool hasExtras = !extraBindings.empty();
2855  if (hasExtras) {
2856  printer << ", ";
2857  printer.printOperands(extraBindings);
2858  }
2859 
2860  printer << " : ";
2861  if (hasExtras)
2862  printer << "(";
2863 
2864  printer << rootType;
2865  if (hasExtras)
2866  printer << ", " << llvm::interleaved(extraBindingTypes) << ')';
2867 }
2868 
2869 /// Returns `true` if the given op operand may be consuming the handle value in
2870 /// the Transform IR. That is, if it may have a Free effect on it.
2872  // Conservatively assume the effect being present in absence of the interface.
2873  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2874  if (!iface)
2875  return true;
2876 
2877  return isHandleConsumed(use.get(), iface);
2878 }
2879 
2880 LogicalResult
2882  function_ref<InFlightDiagnostic()> reportError) {
2883  OpOperand *potentialConsumer = nullptr;
2884  for (OpOperand &use : value.getUses()) {
2885  if (!isValueUsePotentialConsumer(use))
2886  continue;
2887 
2888  if (!potentialConsumer) {
2889  potentialConsumer = &use;
2890  continue;
2891  }
2892 
2893  InFlightDiagnostic diag = reportError()
2894  << " has more than one potential consumer";
2895  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2896  << "used here as operand #" << potentialConsumer->getOperandNumber();
2897  diag.attachNote(use.getOwner()->getLoc())
2898  << "used here as operand #" << use.getOperandNumber();
2899  return diag;
2900  }
2901 
2902  return success();
2903 }
2904 
2905 LogicalResult transform::SequenceOp::verify() {
2906  assert(getBodyBlock()->getNumArguments() >= 1 &&
2907  "the number of arguments must have been verified to be more than 1 by "
2908  "PossibleTopLevelTransformOpTrait");
2909 
2910  if (!getRoot() && !getExtraBindings().empty()) {
2911  return emitOpError()
2912  << "does not expect extra operands when used as top-level";
2913  }
2914 
2915  // Check if a block argument has more than one consuming use.
2916  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2917  if (failed(checkDoubleConsume(arg, [this, arg]() {
2918  return (emitOpError() << "block argument #" << arg.getArgNumber());
2919  }))) {
2920  return failure();
2921  }
2922  }
2923 
2924  // Check properties of the nested operations they cannot check themselves.
2925  for (Operation &child : *getBodyBlock()) {
2926  if (!isa<TransformOpInterface>(child) &&
2927  &child != &getBodyBlock()->back()) {
2929  emitOpError()
2930  << "expected children ops to implement TransformOpInterface";
2931  diag.attachNote(child.getLoc()) << "op without interface";
2932  return diag;
2933  }
2934 
2935  for (OpResult result : child.getResults()) {
2936  auto report = [&]() {
2937  return (child.emitError() << "result #" << result.getResultNumber());
2938  };
2939  if (failed(checkDoubleConsume(result, report)))
2940  return failure();
2941  }
2942  }
2943 
2944  if (!getBodyBlock()->mightHaveTerminator())
2945  return emitOpError() << "expects to have a terminator in the body";
2946 
2947  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2948  getOperation()->getResultTypes()) {
2949  InFlightDiagnostic diag = emitOpError()
2950  << "expects the types of the terminator operands "
2951  "to match the types of the result";
2952  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2953  return diag;
2954  }
2955  return success();
2956 }
2957 
2958 void transform::SequenceOp::getEffects(
2959  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2960  getPotentialTopLevelEffects(effects);
2961 }
2962 
2964 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2965  assert(point == getBody() && "unexpected region index");
2966  if (getOperation()->getNumOperands() > 0)
2967  return getOperation()->getOperands();
2968  return OperandRange(getOperation()->operand_end(),
2969  getOperation()->operand_end());
2970 }
2971 
2972 void transform::SequenceOp::getSuccessorRegions(
2973  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2974  if (point.isParent()) {
2975  Region *bodyRegion = &getBody();
2976  regions.emplace_back(bodyRegion, getNumOperands() != 0
2977  ? bodyRegion->getArguments()
2979  return;
2980  }
2981 
2982  assert(point == getBody() && "unexpected region index");
2983  regions.emplace_back(getOperation()->getResults());
2984 }
2985 
2986 void transform::SequenceOp::getRegionInvocationBounds(
2987  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2988  (void)operands;
2989  bounds.emplace_back(1, 1);
2990 }
2991 
2992 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2993  TypeRange resultTypes,
2994  FailurePropagationMode failurePropagationMode,
2995  Value root,
2996  SequenceBodyBuilderFn bodyBuilder) {
2997  build(builder, state, resultTypes, failurePropagationMode, root,
2998  /*extra_bindings=*/ValueRange());
2999  Type bbArgType = root.getType();
3000  buildSequenceBody(builder, state, bbArgType,
3001  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3002 }
3003 
3004 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3005  TypeRange resultTypes,
3006  FailurePropagationMode failurePropagationMode,
3007  Value root, ValueRange extraBindings,
3008  SequenceBodyBuilderArgsFn bodyBuilder) {
3009  build(builder, state, resultTypes, failurePropagationMode, root,
3010  extraBindings);
3011  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
3012  bodyBuilder);
3013 }
3014 
3015 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3016  TypeRange resultTypes,
3017  FailurePropagationMode failurePropagationMode,
3018  Type bbArgType,
3019  SequenceBodyBuilderFn bodyBuilder) {
3020  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3021  /*extra_bindings=*/ValueRange());
3022  buildSequenceBody(builder, state, bbArgType,
3023  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
3024 }
3025 
3026 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
3027  TypeRange resultTypes,
3028  FailurePropagationMode failurePropagationMode,
3029  Type bbArgType, TypeRange extraBindingTypes,
3030  SequenceBodyBuilderArgsFn bodyBuilder) {
3031  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
3032  /*extra_bindings=*/ValueRange());
3033  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
3034 }
3035 
3036 //===----------------------------------------------------------------------===//
3037 // PrintOp
3038 //===----------------------------------------------------------------------===//
3039 
3040 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3041  StringRef name) {
3042  if (!name.empty())
3043  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
3044 }
3045 
3046 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
3047  Value target, StringRef name) {
3048  result.addOperands({target});
3049  build(builder, result, name);
3050 }
3051 
3053 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
3054  transform::TransformResults &results,
3055  transform::TransformState &state) {
3056  llvm::outs() << "[[[ IR printer: ";
3057  if (getName().has_value())
3058  llvm::outs() << *getName() << " ";
3059 
3060  OpPrintingFlags printFlags;
3061  if (getAssumeVerified().value_or(false))
3062  printFlags.assumeVerified();
3063  if (getUseLocalScope().value_or(false))
3064  printFlags.useLocalScope();
3065  if (getSkipRegions().value_or(false))
3066  printFlags.skipRegions();
3067 
3068  if (!getTarget()) {
3069  llvm::outs() << "top-level ]]]\n";
3070  state.getTopLevel()->print(llvm::outs(), printFlags);
3071  llvm::outs() << "\n";
3072  llvm::outs().flush();
3074  }
3075 
3076  llvm::outs() << "]]]\n";
3077  for (Operation *target : state.getPayloadOps(getTarget())) {
3078  target->print(llvm::outs(), printFlags);
3079  llvm::outs() << "\n";
3080  }
3081 
3082  llvm::outs().flush();
3084 }
3085 
3086 void transform::PrintOp::getEffects(
3087  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3088  // We don't really care about mutability here, but `getTarget` now
3089  // unconditionally casts to a specific type before verification could run
3090  // here.
3091  if (!getTargetMutable().empty())
3092  onlyReadsHandle(getTargetMutable()[0], effects);
3093  onlyReadsPayload(effects);
3094 
3095  // There is no resource for stderr file descriptor, so just declare print
3096  // writes into the default resource.
3097  effects.emplace_back(MemoryEffects::Write::get());
3098 }
3099 
3100 //===----------------------------------------------------------------------===//
3101 // VerifyOp
3102 //===----------------------------------------------------------------------===//
3103 
3105 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
3106  Operation *target,
3108  transform::TransformState &state) {
3109  if (failed(::mlir::verify(target))) {
3111  << "failed to verify payload op";
3112  diag.attachNote(target->getLoc()) << "payload op";
3113  return diag;
3114  }
3116 }
3117 
3118 void transform::VerifyOp::getEffects(
3119  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3120  transform::onlyReadsHandle(getTargetMutable(), effects);
3121 }
3122 
3123 //===----------------------------------------------------------------------===//
3124 // YieldOp
3125 //===----------------------------------------------------------------------===//
3126 
3127 void transform::YieldOp::getEffects(
3128  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3129  onlyReadsHandle(getOperandsMutable(), effects);
3130 }
static ParseResult parseKeyValuePair(AsmParser &parser, DataLayoutEntryInterface &entry, bool tryType=false)
Parse an entry which can either be of the form key = value or a #dlti.dl_entry attribute.
Definition: DLTI.cpp:38
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue >> blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static ParseResult parseApplyRegisteredPassOptions(OpAsmParser &parser, DictionaryAttr &options, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dynamicOptions)
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DEBUG_TYPE_MATCHER
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, Operation *op, DictionaryAttr options, ValueRange dynamicOptions)
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ None
Zero or more operands with no delimiters.
@ Braces
{} brackets surrounding zero or more operands.
virtual ParseResult parseOptionalKeywordOrString(std::string *result)=0
Parse an optional keyword or string.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:77
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
MutableArrayRef< BlockArgument > BlockArgListType
Definition: Block.h:85
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:212
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:111
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:140
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:296
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:288
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:282
This is a value defined by a result of an operation.
Definition: Value.h:447
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
A wrapper class that allows for printing an operation with a set of flags, useful to act as a "stream...
Definition: Operation.h:1111
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getType() const
Definition: ValueRange.cpp:32
type_range getTypes() const
Definition: ValueRange.cpp:28
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getOpResults()
Definition: Operation.h:420
result_range getResults()
Definition: Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:53
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:232
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:52
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:58
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
user_range getUsers() const
Definition: Value.h:218
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: WalkResult.h:51
static WalkResult interrupt()
Definition: WalkResult.h:46
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
static void printOptionValue(raw_ostream &os, const bool &value)
Utility methods for printing option values.
Definition: PassOptions.h:60
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:378
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.