MLIR  20.0.0git
TransformOps.cpp
Go to the documentation of this file.
1 //===- TransformOps.cpp - Transform dialect operations --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/Verifier.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Pass/PassManager.h"
32 #include "mlir/Pass/PassRegistry.h"
33 #include "mlir/Transforms/CSE.h"
37 #include "llvm/ADT/DenseSet.h"
38 #include "llvm/ADT/STLExtras.h"
39 #include "llvm/ADT/ScopeExit.h"
40 #include "llvm/ADT/SmallPtrSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/Debug.h"
43 #include "llvm/Support/ErrorHandling.h"
44 #include <optional>
45 
46 #define DEBUG_TYPE "transform-dialect"
47 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
48 
49 #define DEBUG_TYPE_MATCHER "transform-matcher"
50 #define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ")
51 #define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x)
52 
53 using namespace mlir;
54 
55 static ParseResult parseSequenceOpOperands(
56  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
57  Type &rootType,
58  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
59  SmallVectorImpl<Type> &extraBindingTypes);
60 static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op,
61  Value root, Type rootType,
62  ValueRange extraBindings,
63  TypeRange extraBindingTypes);
64 static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op,
65  ArrayAttr matchers, ArrayAttr actions);
66 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
67  ArrayAttr &matchers,
68  ArrayAttr &actions);
69 
70 /// Helper function to check if the given transform op is contained in (or
71 /// equal to) the given payload target op. In that case, an error is returned.
72 /// Transforming transform IR that is currently executing is generally unsafe.
///
/// The check walks from `transform` up through its chain of parent ops. If any
/// op on that chain (including `transform` itself) is `payload`, a definite
/// failure is emitted with a note pointing at the payload op.
///
/// NOTE(review): listing lines 73, 79 and 87 are missing from this extraction
/// (presumably the return type, the declaration of `diag`, and the trailing
/// success return). Restore them from the upstream file before compiling.
74 ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
75  Operation *payload) {
76  Operation *transformAncestor = transform.getOperation();
// Climb parent links until the top of the op tree (null parent) is reached.
77  while (transformAncestor) {
78  if (transformAncestor == payload) {
80  transform.emitDefiniteFailure()
81  << "cannot apply transform to itself (or one of its ancestors)";
82  diag.attachNote(payload->getLoc()) << "target payload op";
83  return diag;
84  }
85  transformAncestor = transformAncestor->getParentOp();
86  }
88 }
89 
90 #define GET_OP_CLASSES
91 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
92 
93 //===----------------------------------------------------------------------===//
94 // AlternativesOp
95 //===----------------------------------------------------------------------===//
96 
/// Returns the operands forwarded for region successor `point`.
/// When `point` is a region (not the parent op) and the op has exactly one
/// operand, that operand range is returned. Otherwise an empty range anchored
/// at `operand_end()` is returned.
/// NOTE(review): listing line 97 (presumably the `OperandRange` return type) is
/// missing from this extraction.
98 transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
99  if (!point.isParent() && getOperation()->getNumOperands() == 1)
100  return getOperation()->getOperands();
// Empty range: nothing is forwarded.
101  return OperandRange(getOperation()->operand_end(),
102  getOperation()->operand_end());
103 }
104 
/// Computes the control-flow successors of `point`.
/// - From the parent op, every alternative region is a successor.
/// - From alternative k, the later alternatives (k+1, ...) are successors, and
///   so are the op's own results.
/// When the op has operands, the successor inputs of each alternative are its
/// block arguments.
/// NOTE(review): listing line 113 (presumably the empty "else" arm of the
/// ternary) is missing from this extraction.
105 void transform::AlternativesOp::getSuccessorRegions(
106  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// Skip the current alternative and every alternative before it.
107  for (Region &alternative : llvm::drop_begin(
108  getAlternatives(),
109  point.isParent() ? 0
110  : point.getRegionOrNull()->getRegionNumber() + 1)) {
111  regions.emplace_back(&alternative, !getOperands().empty()
112  ? alternative.getArguments()
114  }
// From inside an alternative, control may also leave the op.
115  if (!point.isParent())
116  regions.emplace_back(getOperation()->getResults());
117 }
118 
119 void transform::AlternativesOp::getRegionInvocationBounds(
120  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
121  (void)operands;
122  // The region corresponding to the first alternative is always executed, the
123  // remaining may or may not be executed.
124  bounds.reserve(getNumRegions());
125  bounds.emplace_back(1, 1);
126  bounds.resize(getNumRegions(), InvocationBounds(0, 1));
127 }
128 
/// Maps every result of the op that owns `block` to an empty payload list in
/// `results`.
/// NOTE(review): listing line 129 is missing from this extraction. It
/// presumably holds the return type, the function name, and the leading
/// parameters (`block`, possibly `state`). Restore it from upstream.
130  transform::TransformResults &results) {
131  for (const auto &res : block->getParentOp()->getOpResults())
132  results.set(res, {});
133 }
134 
/// Tries each alternative region in order on clones of the scope ops.
///
/// The scope ops are the payload of the `scope` handle, or the top-level op
/// when no scope is given. Each scope op must be isolated from above and must
/// not contain this transform; otherwise a definite failure is returned.
///
/// For each alternative, the scope ops are cloned and the region's first block
/// argument is mapped to the clones. Then the region's transforms are applied:
/// - On a silenceable failure, the clones are erased and the next alternative
///   is tried.
/// - On success, the clones replace the originals and the terminator operands
///   are forwarded to the op's results.
/// If no alternative succeeds, a silenceable error is returned.
///
/// NOTE(review): listing lines 135, 137, 173, 177, 187 and 206 are missing from
/// this extraction. They presumably hold the return type, the `results`
/// parameter, two early definite-failure returns, the declaration of `result`,
/// and the success return.
136 transform::AlternativesOp::apply(transform::TransformRewriter &rewriter,
138  transform::TransformState &state) {
139  SmallVector<Operation *> originals;
140  if (Value scopeHandle = getScope())
141  llvm::append_range(originals, state.getPayloadOps(scopeHandle));
142  else
143  originals.push_back(state.getTopLevel());
144
// Validate every scope before mutating anything.
145  for (Operation *original : originals) {
146  if (original->isAncestor(getOperation())) {
147  auto diag = emitDefiniteFailure()
148  << "scope must not contain the transforms being applied";
149  diag.attachNote(original->getLoc()) << "scope";
150  return diag;
151  }
152  if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
153  auto diag = emitDefiniteFailure()
154  << "only isolated-from-above ops can be alternative scopes";
155  diag.attachNote(original->getLoc()) << "scope";
156  return diag;
157  }
158  }
159
160  for (Region &reg : getAlternatives()) {
161  // Clone the scope operations and make the transforms in this alternative
162  // region apply to them by virtue of mapping the block argument (the only
163  // visible handle) to the cloned scope operations. This effectively prevents
164  // the transformation from accessing any IR outside the scope.
165  auto scope = state.make_region_scope(reg);
166  auto clones = llvm::to_vector(
167  llvm::map_range(originals, [](Operation *op) { return op->clone(); }));
// Clones are discarded on every exit path unless released below.
168  auto deleteClones = llvm::make_scope_exit([&] {
169  for (Operation *clone : clones)
170  clone->erase();
171  });
172  if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones)))
174
// NOTE(review): this local shadows the free function mlir::failed, hence the
// qualified `::mlir::failed` call further down.
175  bool failed = false;
176  for (Operation &transform : reg.front().without_terminator()) {
178  state.applyTransform(cast<TransformOpInterface>(transform));
179  if (result.isSilenceableFailure()) {
180  LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage()
181  << "\n");
182  failed = true;
183  break;
184  }
185
// Definite failures abort the whole op.
186  if (::mlir::failed(result.silence()))
188  }
189
190  // If all operations in the given alternative succeeded, no need to consider
191  // the rest. Replace the original scoping operation with the clone on which
192  // the transformations were performed.
193  if (!failed) {
194  // We will be using the clones, so cancel their scheduled deletion.
195  deleteClones.release();
196  TrackingListener listener(state, *this);
// NOTE(review): this local `rewriter` shadows the TransformRewriter
// parameter of the same name.
197  IRRewriter rewriter(getContext(), &listener);
198  for (const auto &kvp : llvm::zip(originals, clones)) {
199  Operation *original = std::get<0>(kvp);
200  Operation *clone = std::get<1>(kvp);
// Splice the clone right before the original, then redirect all uses.
201  original->getBlock()->getOperations().insert(original->getIterator(),
202  clone);
203  rewriter.replaceOp(original, clone->getResults());
204  }
205  detail::forwardTerminatorOperands(&reg.front(), state, results);
207  }
208  }
209  return emitSilenceableError() << "all alternatives failed";
210 }
211 
212 void transform::AlternativesOp::getEffects(
213  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
214  consumesHandle(getOperation()->getOpOperands(), effects);
215  producesHandle(getOperation()->getOpResults(), effects);
216  for (Region *region : getRegions()) {
217  if (!region->empty())
218  producesHandle(region->front().getArguments(), effects);
219  }
220  modifiesPayload(effects);
221 }
222 
223 LogicalResult transform::AlternativesOp::verify() {
224  for (Region &alternative : getAlternatives()) {
225  Block &block = alternative.front();
226  Operation *terminator = block.getTerminator();
227  if (terminator->getOperands().getTypes() != getResults().getTypes()) {
228  InFlightDiagnostic diag = emitOpError()
229  << "expects terminator operands to have the "
230  "same type as results of the operation";
231  diag.attachNote(terminator->getLoc()) << "terminator";
232  return diag;
233  }
234  }
235 
236  return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // AnnotateOp
241 //===----------------------------------------------------------------------===//
242 
/// Sets the attribute named `getName()` on every target payload op.
/// - With a parameter handle holding exactly one attribute, that attribute is
///   broadcast to all targets.
/// - With any other number of attributes, the count must equal the number of
///   targets, and attributes are assigned pairwise; on a mismatch a
///   silenceable error is returned.
/// - Without a parameter, the attribute set up on the missing line 250 is used.
/// NOTE(review): listing lines 243, 245, 250, 261 and 267 are missing from this
/// extraction. They presumably hold the return type, the `results` parameter,
/// the default `attr` declaration (likely a unit attribute — confirm), and two
/// success returns.
244 transform::AnnotateOp::apply(transform::TransformRewriter &rewriter,
246  transform::TransformState &state) {
247  SmallVector<Operation *> targets =
248  llvm::to_vector(state.getPayloadOps(getTarget()));
249
251  if (auto paramH = getParam()) {
252  ArrayRef<Attribute> params = state.getParams(paramH);
253  if (params.size() != 1) {
254  if (targets.size() != params.size()) {
255  return emitSilenceableError()
256  << "parameter and target have different payload lengths ("
257  << params.size() << " vs " << targets.size() << ")";
258  }
// NOTE(review): the structured binding `attr` shadows the outer `attr`.
259  for (auto &&[target, attr] : llvm::zip_equal(targets, params))
260  target->setAttr(getName(), attr);
262  }
// Single parameter: broadcast it to every target below.
263  attr = params[0];
264  }
265  for (auto *target : targets)
266  target->setAttr(getName(), attr);
268 }
269 
270 void transform::AnnotateOp::getEffects(
271  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
272  onlyReadsHandle(getTargetMutable(), effects);
273  onlyReadsHandle(getParamMutable(), effects);
274  modifiesPayload(effects);
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // ApplyCommonSubexpressionEliminationOp
279 //===----------------------------------------------------------------------===//
280 
/// Runs common subexpression elimination on the IR nested in `target`, through
/// the transform rewriter.
/// It refuses (returning the check's failure) when `target` contains or equals
/// this transform op.
/// NOTE(review): listing lines 281, 288 and 294 are missing from this
/// extraction. They presumably hold the return type, the
/// `ensurePayloadIsSeparateFromTransform(*this, target)` call, and the success
/// return.
282 transform::ApplyCommonSubexpressionEliminationOp::applyToOne(
283  transform::TransformRewriter &rewriter, Operation *target,
284  ApplyToEachResultList &results, transform::TransformState &state) {
285  // Make sure that this transform is not applied to itself. Modifying the
286  // transform IR while it is being interpreted is generally dangerous.
287  DiagnosedSilenceableFailure payloadCheck =
289  if (!payloadCheck.succeeded())
290  return payloadCheck;
291
// Dominance info is built fresh for each target.
292  DominanceInfo domInfo;
293  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target);
295 }
296 
/// Memory effects of the CSE op: the target handle is only read.
/// NOTE(review): listing line 300 is missing from this extraction; it
/// presumably declares the payload-modification effect.
297 void transform::ApplyCommonSubexpressionEliminationOp::getEffects(
298  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
299  transform::onlyReadsHandle(getTargetMutable(), effects);
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // ApplyDeadCodeEliminationOp
305 //===----------------------------------------------------------------------===//
306 
/// Erases trivially dead ops nested strictly inside `target`; `target` itself
/// is never erased.
///
/// The algorithm works in two phases:
/// 1. A post-order walk erases ops that are already dead. The in-target
///    defining ops of their operands (and of operands of nested ops) are
///    queued as candidates.
/// 2. A worklist loop erases candidates that have become dead, queuing their
///    own operand producers in turn.
///
/// All erasures go through `rewriter`.
/// NOTE(review): listing lines 313 and 359 are missing from this extraction
/// (presumably the `ensurePayloadIsSeparateFromTransform` argument line and the
/// success return).
307 DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
308  transform::TransformRewriter &rewriter, Operation *target,
309  ApplyToEachResultList &results, transform::TransformState &state) {
310  // Make sure that this transform is not applied to itself. Modifying the
311  // transform IR while it is being interpreted is generally dangerous.
312  DiagnosedSilenceableFailure payloadCheck =
314  if (!payloadCheck.succeeded())
315  return payloadCheck;
316
317  // Maintain a worklist of potentially dead ops.
318  SetVector<Operation *> worklist;
319
320  // Helper function that adds all defining ops of used values (operands and
321  // operands of nested ops).
322  auto addDefiningOpsToWorklist = [&](Operation *op) {
323  op->walk([&](Operation *op) {
324  for (Value v : op->getOperands())
325  if (Operation *defOp = v.getDefiningOp())
// Only ops strictly inside `target` are candidates.
326  if (target->isProperAncestor(defOp))
327  worklist.insert(defOp);
328  });
329  };
330
331  // Helper function that erases an op.
332  auto eraseOp = [&](Operation *op) {
333  // Remove op and nested ops from the worklist.
// NOTE(review): llvm::find is a linear scan of the SetVector per nested op.
334  op->walk([&](Operation *op) {
335  const auto *it = llvm::find(worklist, op);
336  if (it != worklist.end())
337  worklist.erase(it);
338  });
339  rewriter.eraseOp(op);
340  };
341
342  // Initial walk over the IR.
343  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
344  if (op != target && isOpTriviallyDead(op)) {
345  addDefiningOpsToWorklist(op);
346  eraseOp(op);
347  }
348  });
349
350  // Erase all ops that have become dead.
351  while (!worklist.empty()) {
352  Operation *op = worklist.pop_back_val();
353  if (!isOpTriviallyDead(op))
354  continue;
355  addDefiningOpsToWorklist(op);
356  eraseOp(op);
357  }
358
360 }
361 
/// Memory effects of the DCE op: the target handle is only read.
/// NOTE(review): listing line 365 is missing from this extraction; it
/// presumably declares the payload-modification effect.
362 void transform::ApplyDeadCodeEliminationOp::getEffects(
363  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
364  transform::onlyReadsHandle(getTargetMutable(), effects);
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // ApplyPatternsOp
370 //===----------------------------------------------------------------------===//
371 
/// Applies the patterns described by the ops in this op's region to `target`.
///
/// The patterns are applied with the greedy rewrite driver, using the
/// rewriter's listener. When `apply_cse` is set, CSE is run afterwards, and
/// the pattern+CSE sequence is repeated while CSE changes the IR, up to
/// `kNumMaxIterations` rounds.
///
/// Results:
/// - Silenceable failure if the greedy driver fails.
/// - Definite failure if the CSE loop hits the iteration cap.
///
/// NOTE(review): listing lines 380, 401, 404, 427 and 452 are missing from this
/// extraction. They presumably hold the `ensurePayloadIsSeparateFromTransform`
/// argument, the replacement values used when the iteration/rewrite limits are
/// -1, the declaration of `ops`, and the success return.
372 DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
373  transform::TransformRewriter &rewriter, Operation *target,
374  ApplyToEachResultList &results, transform::TransformState &state) {
375  // Make sure that this transform is not applied to itself. Modifying the
376  // transform IR while it is being interpreted is generally dangerous. Even
377  // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver
378  // performs many additional simplifications such as dead code elimination.
379  DiagnosedSilenceableFailure payloadCheck =
381  if (!payloadCheck.succeeded())
382  return payloadCheck;
383
384  // Gather all specified patterns.
385  MLIRContext *ctx = target->getContext();
386  RewritePatternSet patterns(ctx);
387  if (!getRegion().empty()) {
388  for (Operation &op : getRegion().front()) {
389  cast<transform::PatternDescriptorOpInterface>(&op)
390  .populatePatternsWithState(patterns, state);
391  }
392  }
393
394  // Configure the GreedyPatternRewriteDriver.
395  GreedyRewriteConfig config;
// NOTE(review): this downcast assumes the rewriter's listener is a
// RewriterBase::Listener — confirm for all TransformRewriter setups.
396  config.listener =
397  static_cast<RewriterBase::Listener *>(rewriter.getListener());
398  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
399
// A value of -1 (all bits set) is special-cased for both limits.
400  config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
402  : getMaxIterations();
403  config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
405  : getMaxNumRewrites();
406
407  // Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
408  // was requested, apply the greedy pattern rewrite only once. (The greedy
409  // pattern rewrite driver already iterates to a fixpoint internally.)
410  bool cseChanged = false;
411  // One or two iterations should be sufficient. Stop iterating after a certain
412  // threshold to make debugging easier.
413  static const int64_t kNumMaxIterations = 50;
414  int64_t iteration = 0;
415  do {
416  LogicalResult result = failure();
417  if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
418  // Op is isolated from above. Apply patterns and also perform region
419  // simplification.
420  result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
421  } else {
422  // Manually gather list of ops because the other
423  // GreedyPatternRewriteDriver overloads only accept ops that are isolated
424  // from above. This way, patterns can be applied to ops that are not
425  // isolated from above. Regions are not being simplified. Furthermore,
426  // only a single greedy rewrite iteration is performed.
428  target->walk([&](Operation *nestedOp) {
429  if (target != nestedOp)
430  ops.push_back(nestedOp);
431  });
432  result = applyOpPatternsAndFold(ops, frozenPatterns, config);
433  }
434
435  // A failure typically indicates that the pattern application did not
436  // converge.
437  if (failed(result)) {
438  return emitSilenceableFailure(target)
439  << "greedy pattern application failed";
440  }
441
442  if (getApplyCse()) {
443  DominanceInfo domInfo;
444  mlir::eliminateCommonSubExpressions(rewriter, domInfo, target,
445  &cseChanged);
446  }
447  } while (cseChanged && ++iteration < kNumMaxIterations);
448
// Reaching the cap means CSE kept changing the IR on every round.
449  if (iteration == kNumMaxIterations)
450  return emitDefiniteFailure() << "fixpoint iteration did not converge";
451
453 }
454 
455 LogicalResult transform::ApplyPatternsOp::verify() {
456  if (!getRegion().empty()) {
457  for (Operation &op : getRegion().front()) {
458  if (!isa<transform::PatternDescriptorOpInterface>(&op)) {
459  InFlightDiagnostic diag = emitOpError()
460  << "expected children ops to implement "
461  "PatternDescriptorOpInterface";
462  diag.attachNote(op.getLoc()) << "op without interface";
463  return diag;
464  }
465  }
466  }
467  return success();
468 }
469 
/// Memory effects of the apply-patterns op: the target handle is only read.
/// NOTE(review): listing line 473 is missing from this extraction; it
/// presumably declares the payload-modification effect.
470 void transform::ApplyPatternsOp::getEffects(
471  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
472  transform::onlyReadsHandle(getTargetMutable(), effects);
474 }
475 
476 void transform::ApplyPatternsOp::build(
477  OpBuilder &builder, OperationState &result, Value target,
478  function_ref<void(OpBuilder &, Location)> bodyBuilder) {
479  result.addOperands(target);
480 
481  OpBuilder::InsertionGuard g(builder);
482  Region *region = result.addRegion();
483  builder.createBlock(region);
484  if (bodyBuilder)
485  bodyBuilder(builder, result.location);
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // ApplyCanonicalizationPatternsOp
490 //===----------------------------------------------------------------------===//
491 
/// Adds two sets of canonicalization patterns to `patterns`:
/// - those of every loaded dialect;
/// - the op-level canonicalization patterns of each `op` iterated on the
///   missing line 497.
/// NOTE(review): listing line 497 is missing from this extraction. It
/// presumably is a loop over the context's registered operations that binds
/// `op`.
492 void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
493  RewritePatternSet &patterns) {
494  MLIRContext *ctx = patterns.getContext();
495  for (Dialect *dialect : ctx->getLoadedDialects())
496  dialect->getCanonicalizationPatterns(patterns);
498  op.getCanonicalizationPatterns(patterns, ctx);
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // ApplyConversionPatternsOp
503 //===----------------------------------------------------------------------===//
504 
/// Runs a dialect conversion (partial or full) on each target payload op.
///
/// The conversion target is built from the legal/illegal op and dialect
/// attributes. Patterns come from the pattern-descriptor ops. Each descriptor
/// uses its own type converter if it provides one, and otherwise falls back to
/// the default converter. If neither exists, a definite failure is returned.
///
/// When `preserve_handles` is set, a tracking listener updates handles. The
/// listener tolerates replacement ops with different names.
///
/// Results:
/// - A conversion failure yields a silenceable error; a tracking failure in
///   the same round is attached to it as a note.
/// - If only tracking fails, the tracking failure is returned.
/// - Targets are processed in order, so targets converted before a failing one
///   stay converted.
///
/// NOTE(review): listing lines 506, 507, 591, 605 and 629 are missing from this
/// extraction. They presumably hold the remaining parameters, the
/// `ensurePayloadIsSeparateFromTransform` argument, the initialization of
/// `diag` to success, and the success return.
505 DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
508  MLIRContext *ctx = getContext();
509
510  // Instantiate the default type converter if a type converter builder is
511  // specified.
512  std::unique_ptr<TypeConverter> defaultTypeConverter;
513  transform::TypeConverterBuilderOpInterface typeConverterBuilder =
514  getDefaultTypeConverter();
515  if (typeConverterBuilder)
516  defaultTypeConverter = typeConverterBuilder.getTypeConverter();
517
518  // Configure conversion target.
519  ConversionTarget conversionTarget(*getContext());
520  if (getLegalOps())
521  for (Attribute attr : cast<ArrayAttr>(*getLegalOps()))
522  conversionTarget.addLegalOp(
523  OperationName(cast<StringAttr>(attr).getValue(), ctx));
524  if (getIllegalOps())
525  for (Attribute attr : cast<ArrayAttr>(*getIllegalOps()))
526  conversionTarget.addIllegalOp(
527  OperationName(cast<StringAttr>(attr).getValue(), ctx));
528  if (getLegalDialects())
529  for (Attribute attr : cast<ArrayAttr>(*getLegalDialects()))
530  conversionTarget.addLegalDialect(cast<StringAttr>(attr).getValue());
531  if (getIllegalDialects())
532  for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
533  conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
534
535  // Gather all specified patterns.
536  RewritePatternSet patterns(ctx);
537  // The converters must stay alive until after pattern application, because
538  // the patterns hold references to them that would otherwise dangle once they
539  // go out of scope.
540  SmallVector<std::unique_ptr<TypeConverter>> keepAliveConverters;
541  if (!getPatterns().empty()) {
542  for (Operation &op : getPatterns().front()) {
543  auto descriptor =
544  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
545
546  // Check if this pattern set specifies a type converter.
547  std::unique_ptr<TypeConverter> typeConverter =
548  descriptor.getTypeConverter();
549  TypeConverter *converter = nullptr;
550  if (typeConverter) {
551  keepAliveConverters.emplace_back(std::move(typeConverter));
552  converter = keepAliveConverters.back().get();
553  } else {
554  // No type converter specified: Use the default type converter.
555  if (!defaultTypeConverter) {
556  auto diag = emitDefiniteFailure()
557  << "pattern descriptor does not specify type "
558  "converter and apply_conversion_patterns op has "
559  "no default type converter";
560  diag.attachNote(op.getLoc()) << "pattern descriptor op";
561  return diag;
562  }
563  converter = defaultTypeConverter.get();
564  }
565
566  // Add descriptor-specific updates to the conversion target, which may
567  // depend on the final type converter. In structural converters, the
568  // legality of types dictates the dynamic legality of an operation.
569  descriptor.populateConversionTargetRules(*converter, conversionTarget);
570
571  descriptor.populatePatterns(*converter, patterns);
572  }
573  }
574
575  // Attach a tracking listener if handles should be preserved. We configure the
576  // listener to allow op replacements with different names, as conversion
577  // patterns typically replace ops with replacement ops that have a different
578  // name.
579  TrackingListenerConfig trackingConfig;
580  trackingConfig.requireMatchingReplacementOpName = false;
581  ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
582  ConversionConfig conversionConfig;
583  if (getPreserveHandles())
584  conversionConfig.listener = &trackingListener;
585
586  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
587  for (Operation *target : state.getPayloadOps(getTarget())) {
588  // Make sure that this transform is not applied to itself. Modifying the
589  // transform IR while it is being interpreted is generally dangerous.
590  DiagnosedSilenceableFailure payloadCheck =
592  if (!payloadCheck.succeeded())
593  return payloadCheck;
594
595  LogicalResult status = failure();
596  if (getPartialConversion()) {
597  status = applyPartialConversion(target, conversionTarget, frozenPatterns,
598  conversionConfig);
599  } else {
600  status = applyFullConversion(target, conversionTarget, frozenPatterns,
601  conversionConfig);
602  }
603
604  // Check dialect conversion state.
606  if (failed(status)) {
607  diag = emitSilenceableError() << "dialect conversion failed";
608  diag.attachNote(target->getLoc()) << "target op";
609  }
610
611  // Check tracking listener error state.
612  DiagnosedSilenceableFailure trackingFailure =
613  trackingListener.checkAndResetError();
614  if (!trackingFailure.succeeded()) {
615  if (diag.succeeded()) {
616  // Tracking failure is the only failure.
617  return trackingFailure;
618  } else {
// Fold the tracking failure into the conversion diagnostic and silence it
// so it is not reported twice.
619  diag.attachNote() << "tracking listener also failed: "
620  << trackingFailure.getMessage();
621  (void)trackingFailure.silence();
622  }
623  }
624
625  if (!diag.succeeded())
626  return diag;
627  }
628
630 }
631 
// Verifies the apply_conversion_patterns op:
// - it has 1 or 2 regions;
// - every op in the pattern region implements
//   ConversionPatternDescriptorOpInterface;
// - if a second region exists, it holds exactly one
//   TypeConverterBuilderOpInterface op, and every pattern descriptor accepts
//   that default converter.
// NOTE(review): listing lines 632 and 638 are missing from this extraction
// (presumably the `LogicalResult ...::verify() {` signature and the
// `InFlightDiagnostic diag =` declaration).
633  if (getNumRegions() != 1 && getNumRegions() != 2)
634  return emitOpError() << "expected 1 or 2 regions";
635  if (!getPatterns().empty()) {
636  for (Operation &op : getPatterns().front()) {
637  if (!isa<transform::ConversionPatternDescriptorOpInterface>(&op)) {
639  emitOpError() << "expected pattern children ops to implement "
640  "ConversionPatternDescriptorOpInterface";
641  diag.attachNote(op.getLoc()) << "op without interface";
642  return diag;
643  }
644  }
645  }
646  if (getNumRegions() == 2) {
647  Region &typeConverterRegion = getRegion(1);
// NOTE(review): front() is called without checking that region 1 has a
// block — confirm the op definition guarantees one.
648  if (!llvm::hasSingleElement(typeConverterRegion.front()))
649  return emitOpError()
650  << "expected exactly one op in default type converter region";
651  Operation *maybeTypeConverter = &typeConverterRegion.front().front();
652  auto typeConverterOp = dyn_cast<transform::TypeConverterBuilderOpInterface>(
653  maybeTypeConverter);
654  if (!typeConverterOp) {
655  InFlightDiagnostic diag = emitOpError()
656  << "expected default converter child op to "
657  "implement TypeConverterBuilderOpInterface";
658  diag.attachNote(maybeTypeConverter->getLoc()) << "op without interface";
659  return diag;
660  }
661  // Check default type converter type.
662  if (!getPatterns().empty()) {
663  for (Operation &op : getPatterns().front()) {
664  auto descriptor =
665  cast<transform::ConversionPatternDescriptorOpInterface>(&op);
666  if (failed(descriptor.verifyTypeConverter(typeConverterOp)))
667  return failure();
668  }
669  }
670  }
671  return success();
672 }
673 
/// Memory effects of the conversion op:
/// - the target handle is consumed, unless `preserve_handles` is set, in which
///   case it is only read.
/// NOTE(review): listing line 681 is missing from this extraction; it
/// presumably declares the payload-modification effect.
674 void transform::ApplyConversionPatternsOp::getEffects(
675  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676  if (!getPreserveHandles()) {
677  transform::consumesHandle(getTargetMutable(), effects);
678  } else {
679  transform::onlyReadsHandle(getTargetMutable(), effects);
680  }
682 }
683 
684 void transform::ApplyConversionPatternsOp::build(
685  OpBuilder &builder, OperationState &result, Value target,
686  function_ref<void(OpBuilder &, Location)> patternsBodyBuilder,
687  function_ref<void(OpBuilder &, Location)> typeConverterBodyBuilder) {
688  result.addOperands(target);
689 
690  {
691  OpBuilder::InsertionGuard g(builder);
692  Region *region1 = result.addRegion();
693  builder.createBlock(region1);
694  if (patternsBodyBuilder)
695  patternsBodyBuilder(builder, result.location);
696  }
697  {
698  OpBuilder::InsertionGuard g(builder);
699  Region *region2 = result.addRegion();
700  builder.createBlock(region2);
701  if (typeConverterBodyBuilder)
702  typeConverterBodyBuilder(builder, result.location);
703  }
704 }
705 
706 //===----------------------------------------------------------------------===//
707 // ApplyToLLVMConversionPatternsOp
708 //===----------------------------------------------------------------------===//
709 
710 void transform::ApplyToLLVMConversionPatternsOp::populatePatterns(
711  TypeConverter &typeConverter, RewritePatternSet &patterns) {
712  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
713  assert(dialect && "expected that dialect is loaded");
714  auto *iface = cast<ConvertToLLVMPatternInterface>(dialect);
715  // ConversionTarget is currently ignored because the enclosing
716  // apply_conversion_patterns op sets up its own ConversionTarget.
717  ConversionTarget target(*getContext());
718  iface->populateConvertToLLVMConversionPatterns(
719  target, static_cast<LLVMTypeConverter &>(typeConverter), patterns);
720 }
721 
722 LogicalResult transform::ApplyToLLVMConversionPatternsOp::verifyTypeConverter(
723  transform::TypeConverterBuilderOpInterface builder) {
724  if (builder.getTypeConverterType() != "LLVMTypeConverter")
725  return emitOpError("expected LLVMTypeConverter");
726  return success();
727 }
728 
// Verifies two conditions on the dialect named by `dialect_name`:
// - it is loaded;
// - it implements ConvertToLLVMPatternInterface.
// NOTE(review): listing line 729 is missing from this extraction; it presumably
// holds the `LogicalResult transform::ApplyToLLVMConversionPatternsOp::verify()`
// signature.
730  Dialect *dialect = getContext()->getLoadedDialect(getDialectName());
731  if (!dialect)
732  return emitOpError("unknown dialect or dialect not loaded: ")
733  << getDialectName();
// The interface may come from a dialect extension that is registered
// separately from the dialect itself.
734  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
735  if (!iface)
736  return emitOpError(
737  "dialect does not implement ConvertToLLVMPatternInterface or "
738  "extension was not loaded: ")
739  << getDialectName();
740  return success();
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // ApplyLoopInvariantCodeMotionOp
745 //===----------------------------------------------------------------------===//
746 
/// Hoists loop-invariant code out of the `target` loop.
/// NOTE(review): listing lines 747, 750 and 755 are missing from this
/// extraction (presumably the return type, the `results` parameter, and the
/// success return).
748 transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
749  transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
751  transform::TransformState &state) {
752  // Currently, LICM does not remove operations, so we don't need tracking.
753  // If this ever changes, add a LICM entry point that takes a rewriter.
754  moveLoopInvariantCode(target);
756 }
757 
/// Memory effects of the LICM op: the target handle is only read.
/// NOTE(review): listing line 761 is missing from this extraction; it
/// presumably declares the payload-modification effect.
758 void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
759  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
760  transform::onlyReadsHandle(getTargetMutable(), effects);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // ApplyRegisteredPassOp
766 //===----------------------------------------------------------------------===//
767 
/// Runs the registered pass or pass pipeline named `pass_name`, with the given
/// options, on `target`.
///
/// Lookup checks pass pipelines first, then individual passes. Results:
/// - Definite failure for an unknown name or an option-parsing failure.
/// - Silenceable failure if the pipeline itself fails.
/// - On success, `target` is forwarded as the result.
///
/// NOTE(review): listing lines 776 and 805 are missing from this extraction
/// (presumably the `ensurePayloadIsSeparateFromTransform` argument and the
/// success return).
768 DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
769  transform::TransformRewriter &rewriter, Operation *target,
770  ApplyToEachResultList &results, transform::TransformState &state) {
771  // Make sure that this transform is not applied to itself. Modifying the
772  // transform IR while it is being interpreted is generally dangerous. Even
773  // more so when applying passes because they may perform a wide range of IR
774  // modifications.
775  DiagnosedSilenceableFailure payloadCheck =
777  if (!payloadCheck.succeeded())
778  return payloadCheck;
779
780  // Get pass or pass pipeline from registry.
781  const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
782  if (!info)
783  info = PassInfo::lookup(getPassName());
784  if (!info)
785  return emitDefiniteFailure()
786  << "unknown pass or pass pipeline: " << getPassName();
787
788  // Create pass manager and run the pass or pass pipeline.
789  PassManager pm(getContext());
// Option-parsing errors are emitted on this op by the callback, then turned
// into a definite failure below.
790  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
791  emitError(msg);
792  return failure();
793  }))) {
794  return emitDefiniteFailure()
795  << "failed to add pass or pass pipeline to pipeline: "
796  << getPassName();
797  }
798  if (failed(pm.run(target))) {
799  auto diag = emitSilenceableError() << "pass pipeline failed";
800  diag.attachNote(target->getLoc()) << "target op";
801  return diag;
802  }
803
804  results.push_back(target);
806 }
807 
808 //===----------------------------------------------------------------------===//
809 // CastOp
810 //===----------------------------------------------------------------------===//
811 
/// A cast maps each target payload op unchanged onto the result handle; only
/// the handle's static type changes.
/// NOTE(review): listing lines 812 and 817 are missing from this extraction
/// (presumably the return type and the success return).
813 transform::CastOp::applyToOne(transform::TransformRewriter &rewriter,
814  Operation *target, ApplyToEachResultList &results,
815  transform::TransformState &state) {
816  results.push_back(target);
818 }
819 
820 void transform::CastOp::getEffects(
821  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
822  onlyReadsPayload(effects);
823  onlyReadsHandle(getInputMutable(), effects);
824  producesHandle(getOperation()->getOpResults(), effects);
825 }
826 
827 bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
828  assert(inputs.size() == 1 && "expected one input");
829  assert(outputs.size() == 1 && "expected one output");
830  return llvm::all_of(
831  std::initializer_list<Type>{inputs.front(), outputs.front()},
832  llvm::IsaPred<transform::TransformHandleTypeInterface>);
833 }
834 
835 //===----------------------------------------------------------------------===//
836 // Matcher helpers (shared by matcher-based ops)
837 //===----------------------------------------------------------------------===//
838 
839 /// Applies matcher operations from the given `block` using
840 /// `blockArgumentMapping` to initialize block arguments. Updates `state`
841 /// accordingly. Every non-terminator op must implement MatchOpInterface;
842 /// otherwise a definite failure is emitted. The first matcher that does not
843 /// succeed (silenceable or definite failure) stops the walk, and its result is
844 /// returned to the caller unchanged. If all matchers in the block succeed,
845 /// populates `mappings` with the payload entities associated with the block
846 /// terminator operands. Note that `mappings` will be cleared before that.
/// NOTE(review): listing lines 847, 848, 850, 856, 864 and 879 are missing from
/// this extraction. They presumably hold the return type, the function name
/// with the `block` parameter, the `state` parameter, an early failure return,
/// the declaration of `diag`, and the success return.
849  ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
851  SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
852  assert(block.getParent() && "cannot match using a detached block");
// Handles created while matching are scoped to the block's region.
853  auto matchScope = state.make_region_scope(*block.getParent());
854  if (failed(
855  state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
857
858  for (Operation &match : block.without_terminator()) {
859  if (!isa<transform::MatchOpInterface>(match)) {
860  return emitDefiniteFailure(match.getLoc())
861  << "expected operations in the match part to "
862  "implement MatchOpInterface";
863  }
865  state.applyTransform(cast<transform::TransformOpInterface>(match));
866  if (diag.succeeded())
867  continue;
868
869  return diag;
870  }
871
872  // Remember the values mapped to the terminator operands so we can
873  // forward them to the action.
874  ValueRange yieldedValues = block.getTerminator()->getOperands();
875  // Our contract with the caller is that the mappings will contain only the
876  // newly mapped values, clear the rest.
877  mappings.clear();
878  transform::detail::prepareValueMappings(mappings, yieldedValues, state);
880 }
881 
882 /// Returns `true` if both types implement one of the interfaces provided as
883 /// template parameters.
884 template <typename... Tys>
885 static bool implementSameInterface(Type t1, Type t2) {
886  return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
887 }
888 
889 /// Returns `true` if both types implement one of the transform dialect
890 /// interfaces.
892  return implementSameInterface<transform::TransformHandleTypeInterface,
893  transform::TransformParamTypeInterface,
894  transform::TransformValueHandleTypeInterface>(
895  t1, t2);
896 }
897 
898 //===----------------------------------------------------------------------===//
899 // CollectMatchingOp
900 //===----------------------------------------------------------------------===//
901 
903 transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
905  transform::TransformState &state) {
906  auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
907  getOperation(), getMatcher());
908  if (matcher.isExternal()) {
909  return emitDefiniteFailure()
910  << "unresolved external symbol " << getMatcher();
911  }
912 
914  rawResults.resize(getOperation()->getNumResults());
915  std::optional<DiagnosedSilenceableFailure> maybeFailure;
916  for (Operation *root : state.getPayloadOps(getRoot())) {
917  WalkResult walkResult = root->walk([&](Operation *op) {
918  DEBUG_MATCHER({
919  DBGS_MATCHER() << "matching ";
920  op->print(llvm::dbgs(),
921  OpPrintingFlags().assumeVerified().skipRegions());
922  llvm::dbgs() << " @" << op << "\n";
923  });
924 
925  // Try matching.
927  SmallVector<transform::MappedValue> inputMapping({op});
929  matcher.getFunctionBody().front(),
930  ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
931  mappings);
932  if (diag.isDefiniteFailure())
933  return WalkResult::interrupt();
934  if (diag.isSilenceableFailure()) {
935  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
936  << " failed: " << diag.getMessage());
937  return WalkResult::advance();
938  }
939 
940  // If succeeded, collect results.
941  for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
942  if (mapping.size() != 1) {
943  maybeFailure.emplace(emitSilenceableError()
944  << "result #" << i << ", associated with "
945  << mapping.size()
946  << " payload objects, expected 1");
947  return WalkResult::interrupt();
948  }
949  rawResults[i].push_back(mapping[0]);
950  }
951  return WalkResult::advance();
952  });
953  if (walkResult.wasInterrupted())
954  return std::move(*maybeFailure);
955  assert(!maybeFailure && "failure set but the walk was not interrupted");
956 
957  for (auto &&[opResult, rawResult] :
958  llvm::zip_equal(getOperation()->getResults(), rawResults)) {
959  results.setMappedValues(opResult, rawResult);
960  }
961  }
963 }
964 
965 void transform::CollectMatchingOp::getEffects(
966  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
967  onlyReadsHandle(getRootMutable(), effects);
968  producesHandle(getOperation()->getOpResults(), effects);
969  onlyReadsPayload(effects);
970 }
971 
972 LogicalResult transform::CollectMatchingOp::verifySymbolUses(
973  SymbolTableCollection &symbolTable) {
974  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
975  symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
976  if (!matcherSymbol ||
977  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
978  return emitError() << "unresolved matcher symbol " << getMatcher();
979 
980  ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
981  if (argumentTypes.size() != 1 ||
982  !isa<TransformHandleTypeInterface>(argumentTypes[0])) {
983  return emitError()
984  << "expected the matcher to take one operation handle argument";
985  }
986  if (!matcherSymbol.getArgAttr(
987  0, transform::TransformDialect::kArgReadOnlyAttrName)) {
988  return emitError() << "expected the matcher argument to be marked readonly";
989  }
990 
991  ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
992  if (resultTypes.size() != getOperation()->getNumResults()) {
993  return emitError()
994  << "expected the matcher to yield as many values as op has results ("
995  << getOperation()->getNumResults() << "), got "
996  << resultTypes.size();
997  }
998 
999  for (auto &&[i, matcherType, resultType] :
1000  llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
1001  if (implementSameTransformInterface(matcherType, resultType))
1002  continue;
1003 
1004  return emitError()
1005  << "mismatching type interfaces for matcher result and op result #"
1006  << i;
1007  }
1008 
1009  return success();
1010 }
1011 
1012 //===----------------------------------------------------------------------===//
1013 // ForeachMatchOp
1014 //===----------------------------------------------------------------------===//
1015 
1016 // This is fine because nothing is actually consumed by this op.
1017 bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1018 
1020 transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
1021  transform::TransformResults &results,
1022  transform::TransformState &state) {
1024  matchActionPairs;
1025  matchActionPairs.reserve(getMatchers().size());
1026  SymbolTableCollection symbolTable;
1027  for (auto &&[matcher, action] :
1028  llvm::zip_equal(getMatchers(), getActions())) {
1029  auto matcherSymbol =
1030  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1031  getOperation(), cast<SymbolRefAttr>(matcher));
1032  auto actionSymbol =
1033  symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
1034  getOperation(), cast<SymbolRefAttr>(action));
1035  assert(matcherSymbol && actionSymbol &&
1036  "unresolved symbols not caught by the verifier");
1037 
1038  if (matcherSymbol.isExternal())
1039  return emitDefiniteFailure() << "unresolved external symbol " << matcher;
1040  if (actionSymbol.isExternal())
1041  return emitDefiniteFailure() << "unresolved external symbol " << action;
1042 
1043  matchActionPairs.emplace_back(matcherSymbol, actionSymbol);
1044  }
1045 
1046  DiagnosedSilenceableFailure overallDiag =
1048 
1049  SmallVector<SmallVector<MappedValue>> matchInputMapping;
1050  SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1051  SmallVector<SmallVector<MappedValue>> actionResultMapping;
1052  // Explicitly add the mapping for the first block argument (the op being
1053  // matched).
1054  matchInputMapping.emplace_back();
1055  transform::detail::prepareValueMappings(matchInputMapping,
1056  getForwardedInputs(), state);
1057  SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1058  actionResultMapping.resize(getForwardedOutputs().size());
1059 
1060  for (Operation *root : state.getPayloadOps(getRoot())) {
1061  WalkResult walkResult = root->walk([&](Operation *op) {
1062  // If getRestrictRoot is not present, skip over the root op itself so we
1063  // don't invalidate it.
1064  if (!getRestrictRoot() && op == root)
1065  return WalkResult::advance();
1066 
1067  DEBUG_MATCHER({
1068  DBGS_MATCHER() << "matching ";
1069  op->print(llvm::dbgs(),
1070  OpPrintingFlags().assumeVerified().skipRegions());
1071  llvm::dbgs() << " @" << op << "\n";
1072  });
1073 
1074  firstMatchArgument.clear();
1075  firstMatchArgument.push_back(op);
1076 
1077  // Try all the match/action pairs until the first successful match.
1078  for (auto [matcher, action] : matchActionPairs) {
1080  matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1081  state, matchOutputMapping);
1082  if (diag.isDefiniteFailure())
1083  return WalkResult::interrupt();
1084  if (diag.isSilenceableFailure()) {
1085  DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
1086  << " failed: " << diag.getMessage());
1087  continue;
1088  }
1089 
1090  auto scope = state.make_region_scope(action.getFunctionBody());
1091  if (failed(state.mapBlockArguments(
1092  action.getFunctionBody().front().getArguments(),
1093  matchOutputMapping))) {
1094  return WalkResult::interrupt();
1095  }
1096 
1097  for (Operation &transform :
1098  action.getFunctionBody().front().without_terminator()) {
1100  state.applyTransform(cast<TransformOpInterface>(transform));
1101  if (result.isDefiniteFailure())
1102  return WalkResult::interrupt();
1103  if (result.isSilenceableFailure()) {
1104  if (overallDiag.succeeded()) {
1105  overallDiag = emitSilenceableError() << "actions failed";
1106  }
1107  overallDiag.attachNote(action->getLoc())
1108  << "failed action: " << result.getMessage();
1109  overallDiag.attachNote(op->getLoc())
1110  << "when applied to this matching payload";
1111  (void)result.silence();
1112  continue;
1113  }
1114  }
1115  if (failed(detail::appendValueMappings(
1116  MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1117  action.getFunctionBody().front().getTerminator()->getOperands(),
1118  state, getFlattenResults()))) {
1120  << "action @" << action.getName()
1121  << " has results associated with multiple payload entities, "
1122  "but flattening was not requested";
1123  return WalkResult::interrupt();
1124  }
1125  break;
1126  }
1127  return WalkResult::advance();
1128  });
1129  if (walkResult.wasInterrupted())
1131  }
1132 
1133  // The root operation should not have been affected, so we can just reassign
1134  // the payload to the result. Note that we need to consume the root handle to
1135  // make sure any handles to operations inside, that could have been affected
1136  // by actions, are invalidated.
1137  results.set(llvm::cast<OpResult>(getUpdated()),
1138  state.getPayloadOps(getRoot()));
1139  for (auto &&[result, mapping] :
1140  llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1141  results.setMappedValues(result, mapping);
1142  }
1143  return overallDiag;
1144 }
1145 
1146 void transform::ForeachMatchOp::getAsmResultNames(
1147  OpAsmSetValueNameFn setNameFn) {
1148  setNameFn(getUpdated(), "updated_root");
1149  for (Value v : getForwardedOutputs()) {
1150  setNameFn(v, "yielded");
1151  }
1152 }
1153 
1154 void transform::ForeachMatchOp::getEffects(
1155  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1156  // Bail if invalid.
1157  if (getOperation()->getNumOperands() < 1 ||
1158  getOperation()->getNumResults() < 1) {
1159  return modifiesPayload(effects);
1160  }
1161 
1162  consumesHandle(getRootMutable(), effects);
1163  onlyReadsHandle(getForwardedInputsMutable(), effects);
1164  producesHandle(getOperation()->getOpResults(), effects);
1165  modifiesPayload(effects);
1166 }
1167 
1168 /// Parses the comma-separated list of symbol reference pairs of the format
1169 /// `@matcher -> @action`.
1170 static ParseResult parseForeachMatchSymbols(OpAsmParser &parser,
1171  ArrayAttr &matchers,
1172  ArrayAttr &actions) {
1173  StringAttr matcher;
1174  StringAttr action;
1175  SmallVector<Attribute> matcherList;
1176  SmallVector<Attribute> actionList;
1177  do {
1178  if (parser.parseSymbolName(matcher) || parser.parseArrow() ||
1179  parser.parseSymbolName(action)) {
1180  return failure();
1181  }
1182  matcherList.push_back(SymbolRefAttr::get(matcher));
1183  actionList.push_back(SymbolRefAttr::get(action));
1184  } while (parser.parseOptionalComma().succeeded());
1185 
1186  matchers = parser.getBuilder().getArrayAttr(matcherList);
1187  actions = parser.getBuilder().getArrayAttr(actionList);
1188  return success();
1189 }
1190 
1191 /// Prints the comma-separated list of symbol reference pairs of the format
1192 /// `@matcher -> @action`.
1194  ArrayAttr matchers, ArrayAttr actions) {
1195  printer.increaseIndent();
1196  printer.increaseIndent();
1197  for (auto &&[matcher, action, idx] : llvm::zip_equal(
1198  matchers, actions, llvm::seq<unsigned>(0, matchers.size()))) {
1199  printer.printNewline();
1200  printer << cast<SymbolRefAttr>(matcher) << " -> "
1201  << cast<SymbolRefAttr>(action);
1202  if (idx != matchers.size() - 1)
1203  printer << ", ";
1204  }
1205  printer.decreaseIndent();
1206  printer.decreaseIndent();
1207 }
1208 
1209 LogicalResult transform::ForeachMatchOp::verify() {
1210  if (getMatchers().size() != getActions().size())
1211  return emitOpError() << "expected the same number of matchers and actions";
1212  if (getMatchers().empty())
1213  return emitOpError() << "expected at least one match/action pair";
1214 
1215  llvm::SmallPtrSet<Attribute, 8> matcherNames;
1216  for (Attribute name : getMatchers()) {
1217  if (matcherNames.insert(name).second)
1218  continue;
1219  emitWarning() << "matcher " << name
1220  << " is used more than once, only the first match will apply";
1221  }
1222 
1223  return success();
1224 }
1225 
1226 /// Checks that the attributes of the function-like operation have correct
1227 /// consumption effect annotations. If `alsoVerifyInternal`, checks for
1228 /// annotations being present even if they can be inferred from the body.
1230 verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings,
1231  bool alsoVerifyInternal = false) {
1232  auto transformOp = cast<transform::TransformOpInterface>(op.getOperation());
1233  llvm::SmallDenseSet<unsigned> consumedArguments;
1234  if (!op.isExternal()) {
1235  transform::getConsumedBlockArguments(op.getFunctionBody().front(),
1236  consumedArguments);
1237  }
1238  for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
1239  bool isConsumed =
1240  op.getArgAttr(i, transform::TransformDialect::kArgConsumedAttrName) !=
1241  nullptr;
1242  bool isReadOnly =
1243  op.getArgAttr(i, transform::TransformDialect::kArgReadOnlyAttrName) !=
1244  nullptr;
1245  if (isConsumed && isReadOnly) {
1246  return transformOp.emitSilenceableError()
1247  << "argument #" << i << " cannot be both readonly and consumed";
1248  }
1249  if ((op.isExternal() || alsoVerifyInternal) && !isConsumed && !isReadOnly) {
1250  return transformOp.emitSilenceableError()
1251  << "must provide consumed/readonly status for arguments of "
1252  "external or called ops";
1253  }
1254  if (op.isExternal())
1255  continue;
1256 
1257  if (consumedArguments.contains(i) && !isConsumed && isReadOnly) {
1258  return transformOp.emitSilenceableError()
1259  << "argument #" << i
1260  << " is consumed in the body but is not marked as such";
1261  }
1262  if (emitWarnings && !consumedArguments.contains(i) && isConsumed) {
1263  // Cannot use op.emitWarning() here as it would attempt to verify the op
1264  // before printing, resulting in infinite recursion.
1265  emitWarning(op->getLoc())
1266  << "op argument #" << i
1267  << " is not consumed in the body but is marked as consumed";
1268  }
1269  }
1271 }
1272 
1273 LogicalResult transform::ForeachMatchOp::verifySymbolUses(
1274  SymbolTableCollection &symbolTable) {
1275  assert(getMatchers().size() == getActions().size());
1276  auto consumedAttr =
1277  StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
1278  for (auto &&[matcher, action] :
1279  llvm::zip_equal(getMatchers(), getActions())) {
1280  // Presence and typing.
1281  auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
1282  symbolTable.lookupNearestSymbolFrom(getOperation(),
1283  cast<SymbolRefAttr>(matcher)));
1284  auto actionSymbol = dyn_cast_or_null<FunctionOpInterface>(
1285  symbolTable.lookupNearestSymbolFrom(getOperation(),
1286  cast<SymbolRefAttr>(action)));
1287  if (!matcherSymbol ||
1288  !isa<TransformOpInterface>(matcherSymbol.getOperation()))
1289  return emitError() << "unresolved matcher symbol " << matcher;
1290  if (!actionSymbol ||
1291  !isa<TransformOpInterface>(actionSymbol.getOperation()))
1292  return emitError() << "unresolved action symbol " << action;
1293 
1294  if (failed(verifyFunctionLikeConsumeAnnotations(matcherSymbol,
1295  /*emitWarnings=*/false,
1296  /*alsoVerifyInternal=*/true)
1297  .checkAndReport())) {
1298  return failure();
1299  }
1300  if (failed(verifyFunctionLikeConsumeAnnotations(actionSymbol,
1301  /*emitWarnings=*/false,
1302  /*alsoVerifyInternal=*/true)
1303  .checkAndReport())) {
1304  return failure();
1305  }
1306 
1307  // Input -> matcher forwarding.
1308  TypeRange operandTypes = getOperandTypes();
1309  TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1310  if (operandTypes.size() != matcherArguments.size()) {
1312  emitError() << "the number of operands (" << operandTypes.size()
1313  << ") doesn't match the number of matcher arguments ("
1314  << matcherArguments.size() << ") for " << matcher;
1315  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1316  return diag;
1317  }
1318  for (auto &&[i, operand, argument] :
1319  llvm::enumerate(operandTypes, matcherArguments)) {
1320  if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1322  emitOpError()
1323  << "does not expect matcher symbol to consume its operand #" << i;
1324  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1325  return diag;
1326  }
1327 
1328  if (implementSameTransformInterface(operand, argument))
1329  continue;
1330 
1332  emitError()
1333  << "mismatching type interfaces for operand and matcher argument #"
1334  << i << " of matcher " << matcher;
1335  diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1336  return diag;
1337  }
1338 
1339  // Matcher -> action forwarding.
1340  TypeRange matcherResults = matcherSymbol.getResultTypes();
1341  TypeRange actionArguments = actionSymbol.getArgumentTypes();
1342  if (matcherResults.size() != actionArguments.size()) {
1343  return emitError() << "mismatching number of matcher results and "
1344  "action arguments between "
1345  << matcher << " (" << matcherResults.size() << ") and "
1346  << action << " (" << actionArguments.size() << ")";
1347  }
1348  for (auto &&[i, matcherType, actionType] :
1349  llvm::enumerate(matcherResults, actionArguments)) {
1350  if (implementSameTransformInterface(matcherType, actionType))
1351  continue;
1352 
1353  return emitError() << "mismatching type interfaces for matcher result "
1354  "and action argument #"
1355  << i << "of matcher " << matcher << " and action "
1356  << action;
1357  }
1358 
1359  // Action -> result forwarding.
1360  TypeRange actionResults = actionSymbol.getResultTypes();
1361  auto resultTypes = TypeRange(getResultTypes()).drop_front();
1362  if (actionResults.size() != resultTypes.size()) {
1364  emitError() << "the number of action results ("
1365  << actionResults.size() << ") for " << action
1366  << " doesn't match the number of extra op results ("
1367  << resultTypes.size() << ")";
1368  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1369  return diag;
1370  }
1371  for (auto &&[i, resultType, actionType] :
1372  llvm::enumerate(resultTypes, actionResults)) {
1373  if (implementSameTransformInterface(resultType, actionType))
1374  continue;
1375 
1377  emitError() << "mismatching type interfaces for action result #" << i
1378  << " of action " << action << " and op result";
1379  diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
1380  return diag;
1381  }
1382  }
1383  return success();
1384 }
1385 
1386 //===----------------------------------------------------------------------===//
1387 // ForeachOp
1388 //===----------------------------------------------------------------------===//
1389 
1391 transform::ForeachOp::apply(transform::TransformRewriter &rewriter,
1392  transform::TransformResults &results,
1393  transform::TransformState &state) {
1394  // We store the payloads before executing the body as ops may be removed from
1395  // the mapping by the TrackingRewriter while iteration is in progress.
1397  detail::prepareValueMappings(payloads, getTargets(), state);
1398  size_t numIterations = payloads.empty() ? 0 : payloads.front().size();
1399  bool withZipShortest = getWithZipShortest();
1400 
1401  // In case of `zip_shortest`, set the number of iterations to the
1402  // smallest payload in the targets.
1403  if (withZipShortest) {
1404  numIterations =
1405  llvm::min_element(payloads, [&](const SmallVector<MappedValue> &A,
1406  const SmallVector<MappedValue> &B) {
1407  return A.size() < B.size();
1408  })->size();
1409 
1410  for (size_t argIdx = 0; argIdx < payloads.size(); argIdx++)
1411  payloads[argIdx].resize(numIterations);
1412  }
1413 
1414  // As we will be "zipping" over them, check all payloads have the same size.
1415  // `zip_shortest` adjusts all payloads to the same size, so skip this check
1416  // when true.
1417  for (size_t argIdx = 1; !withZipShortest && argIdx < payloads.size();
1418  argIdx++) {
1419  if (payloads[argIdx].size() != numIterations) {
1420  return emitSilenceableError()
1421  << "prior targets' payload size (" << numIterations
1422  << ") differs from payload size (" << payloads[argIdx].size()
1423  << ") of target " << getTargets()[argIdx];
1424  }
1425  }
1426 
1427  // Start iterating, indexing into payloads to obtain the right arguments to
1428  // call the body with - each slice of payloads at the same argument index
1429  // corresponding to a tuple to use as the body's block arguments.
1430  ArrayRef<BlockArgument> blockArguments = getBody().front().getArguments();
1431  SmallVector<SmallVector<MappedValue>> zippedResults(getNumResults(), {});
1432  for (size_t iterIdx = 0; iterIdx < numIterations; iterIdx++) {
1433  auto scope = state.make_region_scope(getBody());
1434  // Set up arguments to the region's block.
1435  for (auto &&[argIdx, blockArg] : llvm::enumerate(blockArguments)) {
1436  MappedValue argument = payloads[argIdx][iterIdx];
1437  // Note that each blockArg's handle gets associated with just a single
1438  // element from the corresponding target's payload.
1439  if (failed(state.mapBlockArgument(blockArg, {argument})))
1441  }
1442 
1443  // Execute loop body.
1444  for (Operation &transform : getBody().front().without_terminator()) {
1445  DiagnosedSilenceableFailure result = state.applyTransform(
1446  llvm::cast<transform::TransformOpInterface>(transform));
1447  if (!result.succeeded())
1448  return result;
1449  }
1450 
1451  // Append yielded payloads to corresponding results from prior iterations.
1452  OperandRange yieldOperands = getYieldOp().getOperands();
1453  for (auto &&[result, yieldOperand, resTuple] :
1454  llvm::zip_equal(getResults(), yieldOperands, zippedResults))
1455  // NB: each iteration we add any number of ops/vals/params to a result.
1456  if (isa<TransformHandleTypeInterface>(result.getType()))
1457  llvm::append_range(resTuple, state.getPayloadOps(yieldOperand));
1458  else if (isa<TransformValueHandleTypeInterface>(result.getType()))
1459  llvm::append_range(resTuple, state.getPayloadValues(yieldOperand));
1460  else if (isa<TransformParamTypeInterface>(result.getType()))
1461  llvm::append_range(resTuple, state.getParams(yieldOperand));
1462  else
1463  assert(false && "unhandled handle type");
1464  }
1465 
1466  // Associate the accumulated result payloads to the op's actual results.
1467  for (auto &&[result, resPayload] : zip_equal(getResults(), zippedResults))
1468  results.setMappedValues(llvm::cast<OpResult>(result), resPayload);
1469 
1471 }
1472 
1473 void transform::ForeachOp::getEffects(
1474  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1475  // NB: this `zip` should be `zip_equal` - while this op's verifier catches
1476  // arity errors, this method might get called before/in absence of `verify()`.
1477  for (auto &&[target, blockArg] :
1478  llvm::zip(getTargetsMutable(), getBody().front().getArguments())) {
1479  BlockArgument blockArgument = blockArg;
1480  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1481  return isHandleConsumed(blockArgument,
1482  cast<TransformOpInterface>(&op));
1483  })) {
1484  consumesHandle(target, effects);
1485  } else {
1486  onlyReadsHandle(target, effects);
1487  }
1488  }
1489 
1490  if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1491  return doesModifyPayload(cast<TransformOpInterface>(&op));
1492  })) {
1493  modifiesPayload(effects);
1494  } else if (any_of(getBody().front().without_terminator(), [&](Operation &op) {
1495  return doesReadPayload(cast<TransformOpInterface>(&op));
1496  })) {
1497  onlyReadsPayload(effects);
1498  }
1499 
1500  producesHandle(getOperation()->getOpResults(), effects);
1501 }
1502 
1503 void transform::ForeachOp::getSuccessorRegions(
1504  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
1505  Region *bodyRegion = &getBody();
1506  if (point.isParent()) {
1507  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1508  return;
1509  }
1510 
1511  // Branch back to the region or the parent.
1512  assert(point == getBody() && "unexpected region index");
1513  regions.emplace_back(bodyRegion, bodyRegion->getArguments());
1514  regions.emplace_back();
1515 }
1516 
1518 transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
1519  // Each block argument handle is mapped to a subset (one op to be precise)
1520  // of the payload of the corresponding `targets` operand of ForeachOp.
1521  assert(point == getBody() && "unexpected region index");
1522  return getOperation()->getOperands();
1523 }
1524 
1525 transform::YieldOp transform::ForeachOp::getYieldOp() {
1526  return cast<transform::YieldOp>(getBody().front().getTerminator());
1527 }
1528 
1529 LogicalResult transform::ForeachOp::verify() {
1530  for (auto [targetOpt, bodyArgOpt] :
1531  llvm::zip_longest(getTargets(), getBody().front().getArguments())) {
1532  if (!targetOpt || !bodyArgOpt)
1533  return emitOpError() << "expects the same number of targets as the body "
1534  "has block arguments";
1535  if (targetOpt.value().getType() != bodyArgOpt.value().getType())
1536  return emitOpError(
1537  "expects co-indexed targets and the body's "
1538  "block arguments to have the same op/value/param type");
1539  }
1540 
1541  for (auto [resultOpt, yieldOperandOpt] :
1542  llvm::zip_longest(getResults(), getYieldOp().getOperands())) {
1543  if (!resultOpt || !yieldOperandOpt)
1544  return emitOpError() << "expects the same number of results as the "
1545  "yield terminator has operands";
1546  if (resultOpt.value().getType() != yieldOperandOpt.value().getType())
1547  return emitOpError("expects co-indexed results and yield "
1548  "operands to have the same op/value/param type");
1549  }
1550 
1551  return success();
1552 }
1553 
1554 //===----------------------------------------------------------------------===//
1555 // GetParentOp
1556 //===----------------------------------------------------------------------===//
1557 
1559 transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
1560  transform::TransformResults &results,
1561  transform::TransformState &state) {
1562  SmallVector<Operation *> parents;
1563  DenseSet<Operation *> resultSet;
1564  for (Operation *target : state.getPayloadOps(getTarget())) {
1565  Operation *parent = target;
1566  for (int64_t i = 0, e = getNthParent(); i < e; ++i) {
1567  parent = parent->getParentOp();
1568  while (parent) {
1569  bool checkIsolatedFromAbove =
1570  !getIsolatedFromAbove() ||
1572  bool checkOpName = !getOpName().has_value() ||
1573  parent->getName().getStringRef() == *getOpName();
1574  if (checkIsolatedFromAbove && checkOpName)
1575  break;
1576  parent = parent->getParentOp();
1577  }
1578  if (!parent) {
1579  if (getAllowEmptyResults()) {
1580  results.set(llvm::cast<OpResult>(getResult()), parents);
1582  }
1584  emitSilenceableError()
1585  << "could not find a parent op that matches all requirements";
1586  diag.attachNote(target->getLoc()) << "target op";
1587  return diag;
1588  }
1589  }
1590  if (getDeduplicate()) {
1591  if (!resultSet.contains(parent)) {
1592  parents.push_back(parent);
1593  resultSet.insert(parent);
1594  }
1595  } else {
1596  parents.push_back(parent);
1597  }
1598  }
1599  results.set(llvm::cast<OpResult>(getResult()), parents);
1601 }
1602 
1603 //===----------------------------------------------------------------------===//
1604 // GetConsumersOfResult
1605 //===----------------------------------------------------------------------===//
1606 
1608 transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter,
1609  transform::TransformResults &results,
1610  transform::TransformState &state) {
1611  int64_t resultNumber = getResultNumber();
1612  auto payloadOps = state.getPayloadOps(getTarget());
1613  if (std::empty(payloadOps)) {
1614  results.set(cast<OpResult>(getResult()), {});
1616  }
1617  if (!llvm::hasSingleElement(payloadOps))
1618  return emitDefiniteFailure()
1619  << "handle must be mapped to exactly one payload op";
1620 
1621  Operation *target = *payloadOps.begin();
1622  if (target->getNumResults() <= resultNumber)
1623  return emitDefiniteFailure() << "result number overflow";
1624  results.set(llvm::cast<OpResult>(getResult()),
1625  llvm::to_vector(target->getResult(resultNumber).getUsers()));
1627 }
1628 
1629 //===----------------------------------------------------------------------===//
1630 // GetDefiningOp
1631 //===----------------------------------------------------------------------===//
1632 
1634 transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter,
1635  transform::TransformResults &results,
1636  transform::TransformState &state) {
1637  SmallVector<Operation *> definingOps;
1638  for (Value v : state.getPayloadValues(getTarget())) {
1639  if (llvm::isa<BlockArgument>(v)) {
1641  emitSilenceableError() << "cannot get defining op of block argument";
1642  diag.attachNote(v.getLoc()) << "target value";
1643  return diag;
1644  }
1645  definingOps.push_back(v.getDefiningOp());
1646  }
1647  results.set(llvm::cast<OpResult>(getResult()), definingOps);
1649 }
1650 
1651 //===----------------------------------------------------------------------===//
1652 // GetProducerOfOperand
1653 //===----------------------------------------------------------------------===//
1654 
1656 transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter,
1657  transform::TransformResults &results,
1658  transform::TransformState &state) {
1659  int64_t operandNumber = getOperandNumber();
1660  SmallVector<Operation *> producers;
1661  for (Operation *target : state.getPayloadOps(getTarget())) {
1662  Operation *producer =
1663  target->getNumOperands() <= operandNumber
1664  ? nullptr
1665  : target->getOperand(operandNumber).getDefiningOp();
1666  if (!producer) {
1668  emitSilenceableError()
1669  << "could not find a producer for operand number: " << operandNumber
1670  << " of " << *target;
1671  diag.attachNote(target->getLoc()) << "target op";
1672  return diag;
1673  }
1674  producers.push_back(producer);
1675  }
1676  results.set(llvm::cast<OpResult>(getResult()), producers);
1678 }
1679 
1680 //===----------------------------------------------------------------------===//
1681 // GetOperandOp
1682 //===----------------------------------------------------------------------===//
1683 
1685 transform::GetOperandOp::apply(transform::TransformRewriter &rewriter,
1686  transform::TransformResults &results,
1687  transform::TransformState &state) {
1688  SmallVector<Value> operands;
1689  for (Operation *target : state.getPayloadOps(getTarget())) {
1690  SmallVector<int64_t> operandPositions;
1692  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1693  target->getNumOperands(), operandPositions);
1694  if (diag.isSilenceableFailure()) {
1695  diag.attachNote(target->getLoc())
1696  << "while considering positions of this payload operation";
1697  return diag;
1698  }
1699  llvm::append_range(operands,
1700  llvm::map_range(operandPositions, [&](int64_t pos) {
1701  return target->getOperand(pos);
1702  }));
1703  }
1704  results.setValues(cast<OpResult>(getResult()), operands);
1706 }
1707 
1708 LogicalResult transform::GetOperandOp::verify() {
1709  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1710  getIsInverted(), getIsAll());
1711 }
1712 
1713 //===----------------------------------------------------------------------===//
1714 // GetResultOp
1715 //===----------------------------------------------------------------------===//
1716 
1718 transform::GetResultOp::apply(transform::TransformRewriter &rewriter,
1719  transform::TransformResults &results,
1720  transform::TransformState &state) {
1721  SmallVector<Value> opResults;
1722  for (Operation *target : state.getPayloadOps(getTarget())) {
1723  SmallVector<int64_t> resultPositions;
1725  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
1726  target->getNumResults(), resultPositions);
1727  if (diag.isSilenceableFailure()) {
1728  diag.attachNote(target->getLoc())
1729  << "while considering positions of this payload operation";
1730  return diag;
1731  }
1732  llvm::append_range(opResults,
1733  llvm::map_range(resultPositions, [&](int64_t pos) {
1734  return target->getResult(pos);
1735  }));
1736  }
1737  results.setValues(cast<OpResult>(getResult()), opResults);
1739 }
1740 
1741 LogicalResult transform::GetResultOp::verify() {
1742  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
1743  getIsInverted(), getIsAll());
1744 }
1745 
1746 //===----------------------------------------------------------------------===//
1747 // GetTypeOp
1748 //===----------------------------------------------------------------------===//
1749 
1750 void transform::GetTypeOp::getEffects(
1751  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1752  onlyReadsHandle(getValueMutable(), effects);
1753  producesHandle(getOperation()->getOpResults(), effects);
1754  onlyReadsPayload(effects);
1755 }
1756 
1758 transform::GetTypeOp::apply(transform::TransformRewriter &rewriter,
1759  transform::TransformResults &results,
1760  transform::TransformState &state) {
1761  SmallVector<Attribute> params;
1762  for (Value value : state.getPayloadValues(getValue())) {
1763  Type type = value.getType();
1764  if (getElemental()) {
1765  if (auto shaped = dyn_cast<ShapedType>(type)) {
1766  type = shaped.getElementType();
1767  }
1768  }
1769  params.push_back(TypeAttr::get(type));
1770  }
1771  results.setParams(cast<OpResult>(getResult()), params);
1773 }
1774 
1775 //===----------------------------------------------------------------------===//
1776 // IncludeOp
1777 //===----------------------------------------------------------------------===//
1778 
1779 /// Applies the transform ops contained in `block`. Maps `results` to the same
1780 /// values as the operands of the block terminator.
1782 applySequenceBlock(Block &block, transform::FailurePropagationMode mode,
1784  transform::TransformResults &results) {
1785  // Apply the sequenced ops one by one.
1786  for (Operation &transform : block.without_terminator()) {
1788  state.applyTransform(cast<transform::TransformOpInterface>(transform));
1789  if (result.isDefiniteFailure())
1790  return result;
1791 
1792  if (result.isSilenceableFailure()) {
1793  if (mode == transform::FailurePropagationMode::Propagate) {
1794  // Propagate empty results in case of early exit.
1795  forwardEmptyOperands(&block, state, results);
1796  return result;
1797  }
1798  (void)result.silence();
1799  }
1800  }
1801 
1802  // Forward the operation mapping for values yielded from the sequence to the
1803  // values produced by the sequence op.
1804  transform::detail::forwardTerminatorOperands(&block, state, results);
1806 }
1807 
1809 transform::IncludeOp::apply(transform::TransformRewriter &rewriter,
1810  transform::TransformResults &results,
1811  transform::TransformState &state) {
1812  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1813  getOperation(), getTarget());
1814  assert(callee && "unverified reference to unknown symbol");
1815 
1816  if (callee.isExternal())
1817  return emitDefiniteFailure() << "unresolved external named sequence";
1818 
1819  // Map operands to block arguments.
1821  detail::prepareValueMappings(mappings, getOperands(), state);
1822  auto scope = state.make_region_scope(callee.getBody());
1823  for (auto &&[arg, map] :
1824  llvm::zip_equal(callee.getBody().front().getArguments(), mappings)) {
1825  if (failed(state.mapBlockArgument(arg, map)))
1827  }
1828 
1830  callee.getBody().front(), getFailurePropagationMode(), state, results);
1831  mappings.clear();
1833  mappings, callee.getBody().front().getTerminator()->getOperands(), state);
1834  for (auto &&[result, mapping] : llvm::zip_equal(getResults(), mappings))
1835  results.setMappedValues(result, mapping);
1836  return result;
1837 }
1838 
1840 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings);
1841 
1842 void transform::IncludeOp::getEffects(
1843  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1844  // Always mark as modifying the payload.
1845  // TODO: a mechanism to annotate effects on payload. Even when all handles are
1846  // only read, the payload may still be modified, so we currently stay on the
1847  // conservative side and always indicate modification. This may prevent some
1848  // code reordering.
1849  modifiesPayload(effects);
1850 
1851  // Results are always produced.
1852  producesHandle(getOperation()->getOpResults(), effects);
1853 
1854  // Adds default effects to operands and results. This will be added if
1855  // preconditions fail so the trait verifier doesn't complain about missing
1856  // effects and the real precondition failure is reported later on.
1857  auto defaultEffects = [&] {
1858  onlyReadsHandle(getOperation()->getOpOperands(), effects);
1859  };
1860 
1861  // Bail if the callee is unknown. This may run as part of the verification
1862  // process before we verified the validity of the callee or of this op.
1863  auto target =
1864  getOperation()->getAttrOfType<SymbolRefAttr>(getTargetAttrName());
1865  if (!target)
1866  return defaultEffects();
1867  auto callee = SymbolTable::lookupNearestSymbolFrom<NamedSequenceOp>(
1868  getOperation(), getTarget());
1869  if (!callee)
1870  return defaultEffects();
1871  DiagnosedSilenceableFailure earlyVerifierResult =
1872  verifyNamedSequenceOp(callee, /*emitWarnings=*/false);
1873  if (!earlyVerifierResult.succeeded()) {
1874  (void)earlyVerifierResult.silence();
1875  return defaultEffects();
1876  }
1877 
1878  for (unsigned i = 0, e = getNumOperands(); i < e; ++i) {
1879  if (callee.getArgAttr(i, TransformDialect::kArgConsumedAttrName))
1880  consumesHandle(getOperation()->getOpOperand(i), effects);
1881  else
1882  onlyReadsHandle(getOperation()->getOpOperand(i), effects);
1883  }
1884 }
1885 
1886 LogicalResult
1887 transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1888  // Access through indirection and do additional checking because this may be
1889  // running before the main op verifier.
1890  auto targetAttr = getOperation()->getAttrOfType<SymbolRefAttr>("target");
1891  if (!targetAttr)
1892  return emitOpError() << "expects a 'target' symbol reference attribute";
1893 
1894  auto target = symbolTable.lookupNearestSymbolFrom<transform::NamedSequenceOp>(
1895  *this, targetAttr);
1896  if (!target)
1897  return emitOpError() << "does not reference a named transform sequence";
1898 
1899  FunctionType fnType = target.getFunctionType();
1900  if (fnType.getNumInputs() != getNumOperands())
1901  return emitError("incorrect number of operands for callee");
1902 
1903  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
1904  if (getOperand(i).getType() != fnType.getInput(i)) {
1905  return emitOpError("operand type mismatch: expected operand type ")
1906  << fnType.getInput(i) << ", but provided "
1907  << getOperand(i).getType() << " for operand number " << i;
1908  }
1909  }
1910 
1911  if (fnType.getNumResults() != getNumResults())
1912  return emitError("incorrect number of results for callee");
1913 
1914  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
1915  Type resultType = getResult(i).getType();
1916  Type funcType = fnType.getResult(i);
1917  if (!implementSameTransformInterface(resultType, funcType)) {
1918  return emitOpError() << "type of result #" << i
1919  << " must implement the same transform dialect "
1920  "interface as the corresponding callee result";
1921  }
1922  }
1923 
1925  cast<FunctionOpInterface>(*target), /*emitWarnings=*/false,
1926  /*alsoVerifyInternal=*/true)
1927  .checkAndReport();
1928 }
1929 
1930 //===----------------------------------------------------------------------===//
1931 // MatchOperationEmptyOp
1932 //===----------------------------------------------------------------------===//
1933 
1934 DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
1935  ::std::optional<::mlir::Operation *> maybeCurrent,
1937  if (!maybeCurrent.has_value()) {
1938  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; });
1940  }
1941  DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; });
1942  return emitSilenceableError() << "operation is not empty";
1943 }
1944 
1945 //===----------------------------------------------------------------------===//
1946 // MatchOperationNameOp
1947 //===----------------------------------------------------------------------===//
1948 
1949 DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation(
1950  Operation *current, transform::TransformResults &results,
1951  transform::TransformState &state) {
1952  StringRef currentOpName = current->getName().getStringRef();
1953  for (auto acceptedAttr : getOpNames().getAsRange<StringAttr>()) {
1954  if (acceptedAttr.getValue() == currentOpName)
1956  }
1957  return emitSilenceableError() << "wrong operation name";
1958 }
1959 
1960 //===----------------------------------------------------------------------===//
1961 // MatchParamCmpIOp
1962 //===----------------------------------------------------------------------===//
1963 
1965 transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter,
1966  transform::TransformResults &results,
1967  transform::TransformState &state) {
1968  auto signedAPIntAsString = [&](const APInt &value) {
1969  std::string str;
1970  llvm::raw_string_ostream os(str);
1971  value.print(os, /*isSigned=*/true);
1972  return os.str();
1973  };
1974 
1975  ArrayRef<Attribute> params = state.getParams(getParam());
1976  ArrayRef<Attribute> references = state.getParams(getReference());
1977 
1978  if (params.size() != references.size()) {
1979  return emitSilenceableError()
1980  << "parameters have different payload lengths (" << params.size()
1981  << " vs " << references.size() << ")";
1982  }
1983 
1984  for (auto &&[i, param, reference] : llvm::enumerate(params, references)) {
1985  auto intAttr = llvm::dyn_cast<IntegerAttr>(param);
1986  auto refAttr = llvm::dyn_cast<IntegerAttr>(reference);
1987  if (!intAttr || !refAttr) {
1988  return emitDefiniteFailure()
1989  << "non-integer parameter value not expected";
1990  }
1991  if (intAttr.getType() != refAttr.getType()) {
1992  return emitDefiniteFailure()
1993  << "mismatching integer attribute types in parameter #" << i;
1994  }
1995  APInt value = intAttr.getValue();
1996  APInt refValue = refAttr.getValue();
1997 
1998  // TODO: this copy will not be necessary in C++20.
1999  int64_t position = i;
2000  auto reportError = [&](StringRef direction) {
2002  emitSilenceableError() << "expected parameter to be " << direction
2003  << " " << signedAPIntAsString(refValue)
2004  << ", got " << signedAPIntAsString(value);
2005  diag.attachNote(getParam().getLoc())
2006  << "value # " << position
2007  << " associated with the parameter defined here";
2008  return diag;
2009  };
2010 
2011  switch (getPredicate()) {
2012  case MatchCmpIPredicate::eq:
2013  if (value.eq(refValue))
2014  break;
2015  return reportError("equal to");
2016  case MatchCmpIPredicate::ne:
2017  if (value.ne(refValue))
2018  break;
2019  return reportError("not equal to");
2020  case MatchCmpIPredicate::lt:
2021  if (value.slt(refValue))
2022  break;
2023  return reportError("less than");
2024  case MatchCmpIPredicate::le:
2025  if (value.sle(refValue))
2026  break;
2027  return reportError("less than or equal to");
2028  case MatchCmpIPredicate::gt:
2029  if (value.sgt(refValue))
2030  break;
2031  return reportError("greater than");
2032  case MatchCmpIPredicate::ge:
2033  if (value.sge(refValue))
2034  break;
2035  return reportError("greater than or equal to");
2036  }
2037  }
2039 }
2040 
2041 void transform::MatchParamCmpIOp::getEffects(
2042  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2043  onlyReadsHandle(getParamMutable(), effects);
2044  onlyReadsHandle(getReferenceMutable(), effects);
2045 }
2046 
2047 //===----------------------------------------------------------------------===//
2048 // ParamConstantOp
2049 //===----------------------------------------------------------------------===//
2050 
2052 transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter,
2053  transform::TransformResults &results,
2054  transform::TransformState &state) {
2055  results.setParams(cast<OpResult>(getParam()), {getValue()});
2057 }
2058 
2059 //===----------------------------------------------------------------------===//
2060 // MergeHandlesOp
2061 //===----------------------------------------------------------------------===//
2062 
2064 transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter,
2065  transform::TransformResults &results,
2066  transform::TransformState &state) {
2067  ValueRange handles = getHandles();
2068  if (isa<TransformHandleTypeInterface>(handles.front().getType())) {
2069  SmallVector<Operation *> operations;
2070  for (Value operand : handles)
2071  llvm::append_range(operations, state.getPayloadOps(operand));
2072  if (!getDeduplicate()) {
2073  results.set(llvm::cast<OpResult>(getResult()), operations);
2075  }
2076 
2077  SetVector<Operation *> uniqued(operations.begin(), operations.end());
2078  results.set(llvm::cast<OpResult>(getResult()), uniqued.getArrayRef());
2080  }
2081 
2082  if (llvm::isa<TransformParamTypeInterface>(handles.front().getType())) {
2083  SmallVector<Attribute> attrs;
2084  for (Value attribute : handles)
2085  llvm::append_range(attrs, state.getParams(attribute));
2086  if (!getDeduplicate()) {
2087  results.setParams(cast<OpResult>(getResult()), attrs);
2089  }
2090 
2091  SetVector<Attribute> uniqued(attrs.begin(), attrs.end());
2092  results.setParams(cast<OpResult>(getResult()), uniqued.getArrayRef());
2094  }
2095 
2096  assert(
2097  llvm::isa<TransformValueHandleTypeInterface>(handles.front().getType()) &&
2098  "expected value handle type");
2099  SmallVector<Value> payloadValues;
2100  for (Value value : handles)
2101  llvm::append_range(payloadValues, state.getPayloadValues(value));
2102  if (!getDeduplicate()) {
2103  results.setValues(cast<OpResult>(getResult()), payloadValues);
2105  }
2106 
2107  SetVector<Value> uniqued(payloadValues.begin(), payloadValues.end());
2108  results.setValues(cast<OpResult>(getResult()), uniqued.getArrayRef());
2110 }
2111 
2112 bool transform::MergeHandlesOp::allowsRepeatedHandleOperands() {
2113  // Handles may be the same if deduplicating is enabled.
2114  return getDeduplicate();
2115 }
2116 
2117 void transform::MergeHandlesOp::getEffects(
2118  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2119  onlyReadsHandle(getHandlesMutable(), effects);
2120  producesHandle(getOperation()->getOpResults(), effects);
2121 
2122  // There are no effects on the Payload IR as this is only a handle
2123  // manipulation.
2124 }
2125 
2126 OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) {
2127  if (getDeduplicate() || getHandles().size() != 1)
2128  return {};
2129 
2130  // If deduplication is not required and there is only one operand, it can be
2131  // used directly instead of merging.
2132  return getHandles().front();
2133 }
2134 
2135 //===----------------------------------------------------------------------===//
2136 // NamedSequenceOp
2137 //===----------------------------------------------------------------------===//
2138 
2140 transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter,
2141  transform::TransformResults &results,
2142  transform::TransformState &state) {
2143  if (isExternal())
2144  return emitDefiniteFailure() << "unresolved external named sequence";
2145 
2146  // Map the entry block argument to the list of operations.
2147  // Note: this is the same implementation as PossibleTopLevelTransformOp but
2148  // without attaching the interface / trait since that is tailored to a
2149  // dangling top-level op that does not get "called".
2150  auto scope = state.make_region_scope(getBody());
2152  state, this->getOperation(), getBody())))
2154 
2155  return applySequenceBlock(getBody().front(),
2156  FailurePropagationMode::Propagate, state, results);
2157 }
2158 
2159 void transform::NamedSequenceOp::getEffects(
2160  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {}
2161 
2163  OperationState &result) {
2165  parser, result, /*allowVariadic=*/false,
2166  getFunctionTypeAttrName(result.name),
2167  [](Builder &builder, ArrayRef<Type> inputs, ArrayRef<Type> results,
2169  std::string &) { return builder.getFunctionType(inputs, results); },
2170  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
2171 }
2172 
2175  printer, cast<FunctionOpInterface>(getOperation()), /*isVariadic=*/false,
2176  getFunctionTypeAttrName().getValue(), getArgAttrsAttrName(),
2177  getResAttrsAttrName());
2178 }
2179 
2180 /// Verifies that a symbol function-like transform dialect operation has the
2181 /// signature and the terminator that have conforming types, i.e., types
2182 /// implementing the same transform dialect type interface. If `allowExternal`
2183 /// is set, allow external symbols (declarations) and don't check the terminator
2184 /// as it may not exist.
2186 verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal) {
2187  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2190  << "cannot be defined inside another transform op";
2191  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2192  return diag;
2193  }
2194 
2195  if (op.isExternal() || op.getFunctionBody().empty()) {
2196  if (allowExternal)
2198 
2199  return emitSilenceableFailure(op) << "cannot be external";
2200  }
2201 
2202  if (op.getFunctionBody().front().empty())
2203  return emitSilenceableFailure(op) << "expected a non-empty body block";
2204 
2205  Operation *terminator = &op.getFunctionBody().front().back();
2206  if (!isa<transform::YieldOp>(terminator)) {
2208  << "expected '"
2209  << transform::YieldOp::getOperationName()
2210  << "' as terminator";
2211  diag.attachNote(terminator->getLoc()) << "terminator";
2212  return diag;
2213  }
2214 
2215  if (terminator->getNumOperands() != op.getResultTypes().size()) {
2216  return emitSilenceableFailure(terminator)
2217  << "expected terminator to have as many operands as the parent op "
2218  "has results";
2219  }
2220  for (auto [i, operandType, resultType] : llvm::zip_equal(
2221  llvm::seq<unsigned>(0, terminator->getNumOperands()),
2222  terminator->getOperands().getType(), op.getResultTypes())) {
2223  if (operandType == resultType)
2224  continue;
2225  return emitSilenceableFailure(terminator)
2226  << "the type of the terminator operand #" << i
2227  << " must match the type of the corresponding parent op result ("
2228  << operandType << " vs " << resultType << ")";
2229  }
2230 
2232 }
2233 
2234 /// Verification of a NamedSequenceOp. This does not report the error
2235 /// immediately, so it can be used to check for op's well-formedness before the
2236 /// verifier runs, e.g., during trait verification.
2238 verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings) {
2239  if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
2240  if (!parent->getAttr(
2241  transform::TransformDialect::kWithNamedSequenceAttrName)) {
2244  << "expects the parent symbol table to have the '"
2245  << transform::TransformDialect::kWithNamedSequenceAttrName
2246  << "' attribute";
2247  diag.attachNote(parent->getLoc()) << "symbol table operation";
2248  return diag;
2249  }
2250  }
2251 
2252  if (auto parent = op->getParentOfType<transform::TransformOpInterface>()) {
2255  << "cannot be defined inside another transform op";
2256  diag.attachNote(parent.getLoc()) << "ancestor transform op";
2257  return diag;
2258  }
2259 
2260  if (op.isExternal() || op.getBody().empty())
2261  return verifyFunctionLikeConsumeAnnotations(cast<FunctionOpInterface>(*op),
2262  emitWarnings);
2263 
2264  if (op.getBody().front().empty())
2265  return emitSilenceableFailure(op) << "expected a non-empty body block";
2266 
2267  Operation *terminator = &op.getBody().front().back();
2268  if (!isa<transform::YieldOp>(terminator)) {
2270  << "expected '"
2271  << transform::YieldOp::getOperationName()
2272  << "' as terminator";
2273  diag.attachNote(terminator->getLoc()) << "terminator";
2274  return diag;
2275  }
2276 
2277  if (terminator->getNumOperands() != op.getFunctionType().getNumResults()) {
2278  return emitSilenceableFailure(terminator)
2279  << "expected terminator to have as many operands as the parent op "
2280  "has results";
2281  }
2282  for (auto [i, operandType, resultType] :
2283  llvm::zip_equal(llvm::seq<unsigned>(0, terminator->getNumOperands()),
2284  terminator->getOperands().getType(),
2285  op.getFunctionType().getResults())) {
2286  if (operandType == resultType)
2287  continue;
2288  return emitSilenceableFailure(terminator)
2289  << "the type of the terminator operand #" << i
2290  << " must match the type of the corresponding parent op result ("
2291  << operandType << " vs " << resultType << ")";
2292  }
2293 
2294  auto funcOp = cast<FunctionOpInterface>(*op);
2296  verifyFunctionLikeConsumeAnnotations(funcOp, emitWarnings);
2297  if (!diag.succeeded())
2298  return diag;
2299 
2300  return verifyYieldingSingleBlockOp(funcOp,
2301  /*allowExternal=*/true);
2302 }
2303 
2304 LogicalResult transform::NamedSequenceOp::verify() {
2305  // Actual verification happens in a separate function for reusability.
2306  return verifyNamedSequenceOp(*this, /*emitWarnings=*/true).checkAndReport();
2307 }
2308 
2309 template <typename FnTy>
2310 static void buildSequenceBody(OpBuilder &builder, OperationState &state,
2311  Type bbArgType, TypeRange extraBindingTypes,
2312  FnTy bodyBuilder) {
2313  SmallVector<Type> types;
2314  types.reserve(1 + extraBindingTypes.size());
2315  types.push_back(bbArgType);
2316  llvm::append_range(types, extraBindingTypes);
2317 
2318  OpBuilder::InsertionGuard guard(builder);
2319  Region *region = state.regions.back().get();
2320  Block *bodyBlock =
2321  builder.createBlock(region, region->begin(), types,
2322  SmallVector<Location>(types.size(), state.location));
2323 
2324  // Populate body.
2325  builder.setInsertionPointToStart(bodyBlock);
2326  if constexpr (llvm::function_traits<FnTy>::num_args == 3) {
2327  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
2328  } else {
2329  bodyBuilder(builder, state.location, bodyBlock->getArgument(0),
2330  bodyBlock->getArguments().drop_front());
2331  }
2332 }
2333 
2334 void transform::NamedSequenceOp::build(OpBuilder &builder,
2335  OperationState &state, StringRef symName,
2336  Type rootType, TypeRange resultTypes,
2337  SequenceBodyBuilderFn bodyBuilder,
2339  ArrayRef<DictionaryAttr> argAttrs) {
2340  state.addAttribute(SymbolTable::getSymbolAttrName(),
2341  builder.getStringAttr(symName));
2342  state.addAttribute(getFunctionTypeAttrName(state.name),
2344  rootType, resultTypes)));
2345  state.attributes.append(attrs.begin(), attrs.end());
2346  state.addRegion();
2347 
2348  buildSequenceBody(builder, state, rootType,
2349  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2350 }
2351 
2352 //===----------------------------------------------------------------------===//
2353 // NumAssociationsOp
2354 //===----------------------------------------------------------------------===//
2355 
2357 transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
2358  transform::TransformResults &results,
2359  transform::TransformState &state) {
2360  size_t numAssociations =
2362  .Case([&](TransformHandleTypeInterface opHandle) {
2363  return llvm::range_size(state.getPayloadOps(getHandle()));
2364  })
2365  .Case([&](TransformValueHandleTypeInterface valueHandle) {
2366  return llvm::range_size(state.getPayloadValues(getHandle()));
2367  })
2368  .Case([&](TransformParamTypeInterface param) {
2369  return llvm::range_size(state.getParams(getHandle()));
2370  })
2371  .Default([](Type) {
2372  llvm_unreachable("unknown kind of transform dialect type");
2373  return 0;
2374  });
2375  results.setParams(cast<OpResult>(getNum()),
2376  rewriter.getI64IntegerAttr(numAssociations));
2378 }
2379 
2380 LogicalResult transform::NumAssociationsOp::verify() {
2381  // Verify that the result type accepts an i64 attribute as payload.
2382  auto resultType = cast<TransformParamTypeInterface>(getNum().getType());
2383  return resultType
2384  .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
2385  .checkAndReport();
2386 }
2387 
2388 //===----------------------------------------------------------------------===//
2389 // SelectOp
2390 //===----------------------------------------------------------------------===//
2391 
2393 transform::SelectOp::apply(transform::TransformRewriter &rewriter,
2394  transform::TransformResults &results,
2395  transform::TransformState &state) {
2396  SmallVector<Operation *> result;
2397  auto payloadOps = state.getPayloadOps(getTarget());
2398  for (Operation *op : payloadOps) {
2399  if (op->getName().getStringRef() == getOpName())
2400  result.push_back(op);
2401  }
2402  results.set(cast<OpResult>(getResult()), result);
2404 }
2405 
2406 //===----------------------------------------------------------------------===//
2407 // SplitHandleOp
2408 //===----------------------------------------------------------------------===//
2409 
2410 void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result,
2411  Value target, int64_t numResultHandles) {
2412  result.addOperands(target);
2413  result.addTypes(SmallVector<Type>(numResultHandles, target.getType()));
2414 }
2415 
2417 transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter,
2418  transform::TransformResults &results,
2419  transform::TransformState &state) {
2420  int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle()));
2421  auto produceNumOpsError = [&]() {
2422  return emitSilenceableError()
2423  << getHandle() << " expected to contain " << this->getNumResults()
2424  << " payload ops but it contains " << numPayloadOps
2425  << " payload ops";
2426  };
2427 
2428  // Fail if there are more payload ops than results and no overflow result was
2429  // specified.
2430  if (numPayloadOps > getNumResults() && !getOverflowResult().has_value())
2431  return produceNumOpsError();
2432 
2433  // Fail if there are more results than payload ops. Unless:
2434  // - "fail_on_payload_too_small" is set to "false", or
2435  // - "pass_through_empty_handle" is set to "true" and there are 0 payload ops.
2436  if (numPayloadOps < getNumResults() && getFailOnPayloadTooSmall() &&
2437  (numPayloadOps != 0 || !getPassThroughEmptyHandle()))
2438  return produceNumOpsError();
2439 
2440  // Distribute payload ops.
2441  SmallVector<SmallVector<Operation *, 1>> resultHandles(getNumResults(), {});
2442  if (getOverflowResult())
2443  resultHandles[*getOverflowResult()].reserve(numPayloadOps -
2444  getNumResults());
2445  for (auto &&en : llvm::enumerate(state.getPayloadOps(getHandle()))) {
2446  int64_t resultNum = en.index();
2447  if (resultNum >= getNumResults())
2448  resultNum = *getOverflowResult();
2449  resultHandles[resultNum].push_back(en.value());
2450  }
2451 
2452  // Set transform op results.
2453  for (auto &&it : llvm::enumerate(resultHandles))
2454  results.set(llvm::cast<OpResult>(getResult(it.index())), it.value());
2455 
2457 }
2458 
2459 void transform::SplitHandleOp::getEffects(
2460  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2461  onlyReadsHandle(getHandleMutable(), effects);
2462  producesHandle(getOperation()->getOpResults(), effects);
2463  // There are no effects on the Payload IR as this is only a handle
2464  // manipulation.
2465 }
2466 
2467 LogicalResult transform::SplitHandleOp::verify() {
2468  if (getOverflowResult().has_value() &&
2469  !(*getOverflowResult() < getNumResults()))
2470  return emitOpError("overflow_result is not a valid result index");
2471  return success();
2472 }
2473 
2474 //===----------------------------------------------------------------------===//
2475 // ReplicateOp
2476 //===----------------------------------------------------------------------===//
2477 
2479 transform::ReplicateOp::apply(transform::TransformRewriter &rewriter,
2480  transform::TransformResults &results,
2481  transform::TransformState &state) {
2482  unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern()));
2483  for (const auto &en : llvm::enumerate(getHandles())) {
2484  Value handle = en.value();
2485  if (isa<TransformHandleTypeInterface>(handle.getType())) {
2486  SmallVector<Operation *> current =
2487  llvm::to_vector(state.getPayloadOps(handle));
2488  SmallVector<Operation *> payload;
2489  payload.reserve(numRepetitions * current.size());
2490  for (unsigned i = 0; i < numRepetitions; ++i)
2491  llvm::append_range(payload, current);
2492  results.set(llvm::cast<OpResult>(getReplicated()[en.index()]), payload);
2493  } else {
2494  assert(llvm::isa<TransformParamTypeInterface>(handle.getType()) &&
2495  "expected param type");
2496  ArrayRef<Attribute> current = state.getParams(handle);
2497  SmallVector<Attribute> params;
2498  params.reserve(numRepetitions * current.size());
2499  for (unsigned i = 0; i < numRepetitions; ++i)
2500  llvm::append_range(params, current);
2501  results.setParams(llvm::cast<OpResult>(getReplicated()[en.index()]),
2502  params);
2503  }
2504  }
2506 }
2507 
2508 void transform::ReplicateOp::getEffects(
2509  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2510  onlyReadsHandle(getPatternMutable(), effects);
2511  onlyReadsHandle(getHandlesMutable(), effects);
2512  producesHandle(getOperation()->getOpResults(), effects);
2513 }
2514 
2515 //===----------------------------------------------------------------------===//
2516 // SequenceOp
2517 //===----------------------------------------------------------------------===//
2518 
2520 transform::SequenceOp::apply(transform::TransformRewriter &rewriter,
2521  transform::TransformResults &results,
2522  transform::TransformState &state) {
2523  // Map the entry block argument to the list of operations.
2524  auto scope = state.make_region_scope(*getBodyBlock()->getParent());
2525  if (failed(mapBlockArguments(state)))
2527 
2528  return applySequenceBlock(*getBodyBlock(), getFailurePropagationMode(), state,
2529  results);
2530 }
2531 
2532 static ParseResult parseSequenceOpOperands(
2533  OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
2534  Type &rootType,
2535  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &extraBindings,
2536  SmallVectorImpl<Type> &extraBindingTypes) {
2537  OpAsmParser::UnresolvedOperand rootOperand;
2538  OptionalParseResult hasRoot = parser.parseOptionalOperand(rootOperand);
2539  if (!hasRoot.has_value()) {
2540  root = std::nullopt;
2541  return success();
2542  }
2543  if (failed(hasRoot.value()))
2544  return failure();
2545  root = rootOperand;
2546 
2547  if (succeeded(parser.parseOptionalComma())) {
2548  if (failed(parser.parseOperandList(extraBindings)))
2549  return failure();
2550  }
2551  if (failed(parser.parseColon()))
2552  return failure();
2553 
2554  // The paren is truly optional.
2555  (void)parser.parseOptionalLParen();
2556 
2557  if (failed(parser.parseType(rootType))) {
2558  return failure();
2559  }
2560 
2561  if (!extraBindings.empty()) {
2562  if (parser.parseComma() || parser.parseTypeList(extraBindingTypes))
2563  return failure();
2564  }
2565 
2566  if (extraBindingTypes.size() != extraBindings.size()) {
2567  return parser.emitError(parser.getNameLoc(),
2568  "expected types to be provided for all operands");
2569  }
2570 
2571  // The paren is truly optional.
2572  (void)parser.parseOptionalRParen();
2573  return success();
2574 }
2575 
2577  Value root, Type rootType,
2578  ValueRange extraBindings,
2579  TypeRange extraBindingTypes) {
2580  if (!root)
2581  return;
2582 
2583  printer << root;
2584  bool hasExtras = !extraBindings.empty();
2585  if (hasExtras) {
2586  printer << ", ";
2587  printer.printOperands(extraBindings);
2588  }
2589 
2590  printer << " : ";
2591  if (hasExtras)
2592  printer << "(";
2593 
2594  printer << rootType;
2595  if (hasExtras) {
2596  printer << ", ";
2597  llvm::interleaveComma(extraBindingTypes, printer.getStream());
2598  printer << ")";
2599  }
2600 }
2601 
2602 /// Returns `true` if the given op operand may be consuming the handle value in
2603 /// the Transform IR. That is, if it may have a Free effect on it.
2605  // Conservatively assume the effect being present in absence of the interface.
2606  auto iface = dyn_cast<transform::TransformOpInterface>(use.getOwner());
2607  if (!iface)
2608  return true;
2609 
2610  return isHandleConsumed(use.get(), iface);
2611 }
2612 
2613 LogicalResult
2615  function_ref<InFlightDiagnostic()> reportError) {
2616  OpOperand *potentialConsumer = nullptr;
2617  for (OpOperand &use : value.getUses()) {
2618  if (!isValueUsePotentialConsumer(use))
2619  continue;
2620 
2621  if (!potentialConsumer) {
2622  potentialConsumer = &use;
2623  continue;
2624  }
2625 
2626  InFlightDiagnostic diag = reportError()
2627  << " has more than one potential consumer";
2628  diag.attachNote(potentialConsumer->getOwner()->getLoc())
2629  << "used here as operand #" << potentialConsumer->getOperandNumber();
2630  diag.attachNote(use.getOwner()->getLoc())
2631  << "used here as operand #" << use.getOperandNumber();
2632  return diag;
2633  }
2634 
2635  return success();
2636 }
2637 
2638 LogicalResult transform::SequenceOp::verify() {
2639  assert(getBodyBlock()->getNumArguments() >= 1 &&
2640  "the number of arguments must have been verified to be more than 1 by "
2641  "PossibleTopLevelTransformOpTrait");
2642 
2643  if (!getRoot() && !getExtraBindings().empty()) {
2644  return emitOpError()
2645  << "does not expect extra operands when used as top-level";
2646  }
2647 
2648  // Check if a block argument has more than one consuming use.
2649  for (BlockArgument arg : getBodyBlock()->getArguments()) {
2650  if (failed(checkDoubleConsume(arg, [this, arg]() {
2651  return (emitOpError() << "block argument #" << arg.getArgNumber());
2652  }))) {
2653  return failure();
2654  }
2655  }
2656 
2657  // Check properties of the nested operations they cannot check themselves.
2658  for (Operation &child : *getBodyBlock()) {
2659  if (!isa<TransformOpInterface>(child) &&
2660  &child != &getBodyBlock()->back()) {
2662  emitOpError()
2663  << "expected children ops to implement TransformOpInterface";
2664  diag.attachNote(child.getLoc()) << "op without interface";
2665  return diag;
2666  }
2667 
2668  for (OpResult result : child.getResults()) {
2669  auto report = [&]() {
2670  return (child.emitError() << "result #" << result.getResultNumber());
2671  };
2672  if (failed(checkDoubleConsume(result, report)))
2673  return failure();
2674  }
2675  }
2676 
2677  if (!getBodyBlock()->mightHaveTerminator())
2678  return emitOpError() << "expects to have a terminator in the body";
2679 
2680  if (getBodyBlock()->getTerminator()->getOperandTypes() !=
2681  getOperation()->getResultTypes()) {
2682  InFlightDiagnostic diag = emitOpError()
2683  << "expects the types of the terminator operands "
2684  "to match the types of the result";
2685  diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
2686  return diag;
2687  }
2688  return success();
2689 }
2690 
2691 void transform::SequenceOp::getEffects(
2692  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2693  getPotentialTopLevelEffects(effects);
2694 }
2695 
2697 transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2698  assert(point == getBody() && "unexpected region index");
2699  if (getOperation()->getNumOperands() > 0)
2700  return getOperation()->getOperands();
2701  return OperandRange(getOperation()->operand_end(),
2702  getOperation()->operand_end());
2703 }
2704 
2705 void transform::SequenceOp::getSuccessorRegions(
2706  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2707  if (point.isParent()) {
2708  Region *bodyRegion = &getBody();
2709  regions.emplace_back(bodyRegion, getNumOperands() != 0
2710  ? bodyRegion->getArguments()
2712  return;
2713  }
2714 
2715  assert(point == getBody() && "unexpected region index");
2716  regions.emplace_back(getOperation()->getResults());
2717 }
2718 
2719 void transform::SequenceOp::getRegionInvocationBounds(
2720  ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
2721  (void)operands;
2722  bounds.emplace_back(1, 1);
2723 }
2724 
2725 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2726  TypeRange resultTypes,
2727  FailurePropagationMode failurePropagationMode,
2728  Value root,
2729  SequenceBodyBuilderFn bodyBuilder) {
2730  build(builder, state, resultTypes, failurePropagationMode, root,
2731  /*extra_bindings=*/ValueRange());
2732  Type bbArgType = root.getType();
2733  buildSequenceBody(builder, state, bbArgType,
2734  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2735 }
2736 
2737 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2738  TypeRange resultTypes,
2739  FailurePropagationMode failurePropagationMode,
2740  Value root, ValueRange extraBindings,
2741  SequenceBodyBuilderArgsFn bodyBuilder) {
2742  build(builder, state, resultTypes, failurePropagationMode, root,
2743  extraBindings);
2744  buildSequenceBody(builder, state, root.getType(), extraBindings.getTypes(),
2745  bodyBuilder);
2746 }
2747 
2748 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2749  TypeRange resultTypes,
2750  FailurePropagationMode failurePropagationMode,
2751  Type bbArgType,
2752  SequenceBodyBuilderFn bodyBuilder) {
2753  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2754  /*extra_bindings=*/ValueRange());
2755  buildSequenceBody(builder, state, bbArgType,
2756  /*extraBindingTypes=*/TypeRange(), bodyBuilder);
2757 }
2758 
2759 void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
2760  TypeRange resultTypes,
2761  FailurePropagationMode failurePropagationMode,
2762  Type bbArgType, TypeRange extraBindingTypes,
2763  SequenceBodyBuilderArgsFn bodyBuilder) {
2764  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value(),
2765  /*extra_bindings=*/ValueRange());
2766  buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder);
2767 }
2768 
2769 //===----------------------------------------------------------------------===//
2770 // PrintOp
2771 //===----------------------------------------------------------------------===//
2772 
2773 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2774  StringRef name) {
2775  if (!name.empty())
2776  result.getOrAddProperties<Properties>().name = builder.getStringAttr(name);
2777 }
2778 
2779 void transform::PrintOp::build(OpBuilder &builder, OperationState &result,
2780  Value target, StringRef name) {
2781  result.addOperands({target});
2782  build(builder, result, name);
2783 }
2784 
2786 transform::PrintOp::apply(transform::TransformRewriter &rewriter,
2787  transform::TransformResults &results,
2788  transform::TransformState &state) {
2789  llvm::outs() << "[[[ IR printer: ";
2790  if (getName().has_value())
2791  llvm::outs() << *getName() << " ";
2792 
2793  OpPrintingFlags printFlags;
2794  if (getAssumeVerified().value_or(false))
2795  printFlags.assumeVerified();
2796  if (getUseLocalScope().value_or(false))
2797  printFlags.useLocalScope();
2798  if (getSkipRegions().value_or(false))
2799  printFlags.skipRegions();
2800 
2801  if (!getTarget()) {
2802  llvm::outs() << "top-level ]]]\n";
2803  state.getTopLevel()->print(llvm::outs(), printFlags);
2804  llvm::outs() << "\n";
2806  }
2807 
2808  llvm::outs() << "]]]\n";
2809  for (Operation *target : state.getPayloadOps(getTarget())) {
2810  target->print(llvm::outs(), printFlags);
2811  llvm::outs() << "\n";
2812  }
2813 
2815 }
2816 
2817 void transform::PrintOp::getEffects(
2818  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2819  // We don't really care about mutability here, but `getTarget` now
2820  // unconditionally casts to a specific type before verification could run
2821  // here.
2822  if (!getTargetMutable().empty())
2823  onlyReadsHandle(getTargetMutable()[0], effects);
2824  onlyReadsPayload(effects);
2825 
2826  // There is no resource for stderr file descriptor, so just declare print
2827  // writes into the default resource.
2828  effects.emplace_back(MemoryEffects::Write::get());
2829 }
2830 
2831 //===----------------------------------------------------------------------===//
2832 // VerifyOp
2833 //===----------------------------------------------------------------------===//
2834 
2836 transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
2837  Operation *target,
2839  transform::TransformState &state) {
2840  if (failed(::mlir::verify(target))) {
2842  << "failed to verify payload op";
2843  diag.attachNote(target->getLoc()) << "payload op";
2844  return diag;
2845  }
2847 }
2848 
2849 void transform::VerifyOp::getEffects(
2850  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2851  transform::onlyReadsHandle(getTargetMutable(), effects);
2852 }
2853 
2854 //===----------------------------------------------------------------------===//
2855 // YieldOp
2856 //===----------------------------------------------------------------------===//
2857 
2858 void transform::YieldOp::getEffects(
2859  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2860  onlyReadsHandle(getOperandsMutable(), effects);
2861 }
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static DiagnosedSilenceableFailure matchBlock(Block &block, ArrayRef< SmallVector< transform::MappedValue >> blockArgumentMapping, transform::TransformState &state, SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings)
Applies matcher operations from the given block using blockArgumentMapping to initialize block argume...
static void printForeachMatchSymbols(OpAsmPrinter &printer, Operation *op, ArrayAttr matchers, ArrayAttr actions)
Prints the comma-separated list of symbol reference pairs of the format @matcher -> @action.
static DiagnosedSilenceableFailure verifyYieldingSingleBlockOp(FunctionOpInterface op, bool allowExternal)
Verifies that a symbol function-like transform dialect operation has the signature and the terminator...
#define DBGS_MATCHER()
static void buildSequenceBody(OpBuilder &builder, OperationState &state, Type bbArgType, TypeRange extraBindingTypes, FnTy bodyBuilder)
static void forwardEmptyOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
static bool implementSameInterface(Type t1, Type t2)
Returns true if both types implement one of the interfaces provided as template parameters.
static void printSequenceOpOperands(OpAsmPrinter &printer, Operation *op, Value root, Type rootType, ValueRange extraBindings, TypeRange extraBindingTypes)
static bool isValueUsePotentialConsumer(OpOperand &use)
Returns true if the given op operand may be consuming the handle value in the Transform IR.
static ParseResult parseSequenceOpOperands(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &root, Type &rootType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &extraBindings, SmallVectorImpl< Type > &extraBindingTypes)
static DiagnosedSilenceableFailure applySequenceBlock(Block &block, transform::FailurePropagationMode mode, transform::TransformState &state, transform::TransformResults &results)
Applies the transform ops contained in block.
static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op, bool emitWarnings)
Verification of a NamedSequenceOp.
static DiagnosedSilenceableFailure verifyFunctionLikeConsumeAnnotations(FunctionOpInterface op, bool emitWarnings, bool alsoVerifyInternal=false)
Checks that the attributes of the function-like operation have correct consumption effect annotations...
static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, ArrayAttr &matchers, ArrayAttr &actions)
Parses the comma-separated list of symbol reference pairs of the format @matcher -> @action.
#define DEBUG_MATCHER(x)
LogicalResult checkDoubleConsume(Value value, function_ref< InFlightDiagnostic()> reportError)
#define DBGS()
static bool implementSameTransformInterface(Type t1, Type t2)
Returns true if both types implement one of the transform dialect interfaces.
static DiagnosedSilenceableFailure ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, Operation *payload)
Helper function to check if the given transform op is contained in (or equal to) the given payload ta...
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:77
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:207
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:132
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:273
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:277
This class describes a specific conversion target.
A compatibility class connecting InFlightDiagnostic to DiagnosedSilenceableFailure while providing an...
The result of a transform IR operation application.
LogicalResult silence()
Converts silenceable failure into LogicalResult success without reporting the diagnostic,...
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
Diagnostic & attachNote(std::optional< Location > loc=std::nullopt)
Attaches a note to the last diagnostic.
std::string getMessage() const
Returns the diagnostic message without emitting it.
bool isDefiniteFailure() const
Returns true if this is a definite failure.
LogicalResult checkAndReport()
Converts all kinds of failure into a LogicalResult failure, emitting the diagnostic if necessary.
bool succeeded() const
Returns true if this is a success.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
bool isSilenceableFailure() const
Returns true if this is a silenceable failure.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
A class for computing basic dominance information.
Definition: Dominance.h:136
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
static constexpr int64_t kNoLimit
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and si...
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void decreaseIndent()=0
Decrease indentation.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:323
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:441
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
Set of flags used to control the behavior of the various IR print methods (e.g.
OpPrintingFlags & assumeVerified()
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:280
OpPrintingFlags & useLocalScope()
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:288
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:274
This is a value defined by a result of an operation.
Definition: Value.h:457
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
type_range getType() const
Definition: ValueRange.cpp:30
type_range getTypes() const
Definition: ValueRange.cpp:26
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getOpResults()
Definition: Operation.h:415
result_range getResults()
Definition: Operation.h:410
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
static const PassInfo * lookup(StringRef passArg)
Returns the pass info for the specified pass class or null if unknown.
The main pass manager and pipeline builder.
Definition: PassManager.h:231
static const PassPipelineInfo * lookup(StringRef pipelineArg)
Returns the pass pipeline info for the specified pass pipeline or null if unknown.
Structure to group information about a passes and pass pipelines (argument to invoke via mlir-opt,...
Definition: PassRegistry.h:49
LogicalResult addToPipeline(OpPassManager &pm, StringRef options, function_ref< LogicalResult(const Twine &)> errorHandler) const
Adds this pass registry entry to the given pass manager.
Definition: PassRegistry.h:55
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Region * getRegionOrNull() const
Returns the region if branching from a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
BlockArgListType getArguments()
Definition: Region.h:81
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Definition: Region.cpp:62
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
This is a "type erased" representation of a registered operation.
MLIRContext * getContext() const
Definition: PatternMatch.h:823
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
Type conversion class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
size_t size() const
Return the size of this range.
Definition: TypeRange.h:145
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:129
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
A named class for passing around the variadic flag.
A list of results of applying a transform op with ApplyEachOpTrait to a single payload operation,...
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
This is a special rewriter to be used in transform op implementations, providing additional helper fu...
The state maintained across applications of various ops implementing the TransformOpInterface.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
LogicalResult appendValueMappings(MutableArrayRef< SmallVector< transform::MappedValue >> mappings, ValueRange values, const transform::TransformState &state, bool flatten=true)
Appends the entities associated with the given transform values in state to the pre-existing list of ...
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, Operation *op, Region &region)
Maps the only block argument of the op with PossibleTopLevelTransformOpTrait to either the list of op...
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
void getPotentialTopLevelEffects(Operation *operation, Value root, Block &body, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with side effects implied by PossibleTopLevelTransformOpTrait for the given operati...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform)
Checks whether the transform op consumes the given handle.
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void getConsumedBlockArguments(Block &block, llvm::SmallDenseSet< unsigned > &consumedArguments)
Populates consumedArguments with positions of block arguments that are consumed by the operations in ...
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument, ::mlir::ValueRange)> SequenceBodyBuilderArgsFn
Definition: TransformOps.h:39
bool doesModifyPayload(transform::TransformOpInterface transform)
Checks whether the transform op modifies the payload.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
bool doesReadPayload(transform::TransformOpInterface transform)
Checks whether the transform op reads the payload.
void consumesHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value:
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
::llvm::function_ref< void(::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)> SequenceBodyBuilderFn
A builder function that populates the body of a SequenceOp.
Definition: TransformOps.h:36
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, bool *changed=nullptr)
Eliminate common subexpressions within the given operation.
Definition: CSE.cpp:382
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
size_t moveLoopInvariantCode(ArrayRef< Region * > regions, function_ref< bool(Value, Region *)> isDefinedOutsideRegion, function_ref< bool(Operation *, Region *)> shouldMoveOutOfRegion, function_ref< void(Operation *, Region *)> moveOutOfRegion)
Given a list of regions, perform loop-invariant code motion.
Dialect conversion configuration.
RewriterBase::Listener * listener
An optional listener that is notified about all IR modifications in case dialect conversion succeeds.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
void addOperands(ValueRange newOperands)
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.